add simple test for search request by LHT129 · Pull Request #860 · antgroup/vsag · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

add simple test for search request #860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/data_cell/attribute_bucket_inverted_datacell.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ AttributeBucketInvertedDataCell::get_bitsets_by_type(const ValueMapPtr& value_ma
if (attr_value == nullptr) {
throw VsagException(ErrorType::INTERNAL_ERROR, "Invalid attribute type");
}
bitsets.reserve(attr_value->GetValue().size());
for (auto& value : attr_value->GetValue()) {
auto bitset = value_map->GetBitsetByValue(value);
bitsets.emplace_back(bitset);
Expand Down
2 changes: 1 addition & 1 deletion src/simd/avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ FP32ComputeL2SqrBatch4(const float* RESTRICT query,
float& result2,
float& result3,
float& result4) {
#if defined(ENABLE_AVX2)
#if defined(ENABLE_AVX512)
if (dim < 16) {
return avx2::FP32ComputeL2SqrBatch4(
query, dim, codes1, codes2, codes3, codes4, result1, result2, result3, result4);
Expand Down
24 changes: 22 additions & 2 deletions tests/fixtures/fixtures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,21 @@ FillStringValues(vsag::AttributeValue<std::string>* attr,
}
}

/**
 * @brief Pick distinct indices at random from the range [0, n).
 *
 * Performs the first steps of a Fisher-Yates shuffle over 0..n-1 and keeps
 * only the shuffled prefix, so every returned value is unique.
 *
 * @param n Size of the index range. A non-positive value yields an empty result.
 * @param k Number of indices requested. The request is clamped to [0, n]:
 *          asking for more indices than exist returns all n of them, and a
 *          non-positive request returns an empty vector.
 * @return The selected indices, in random order.
 */
static std::vector<int>
select_k_numbers(int64_t n, int k) {
    // Nothing to draw from, or nothing requested.
    if (n <= 0 || k <= 0) {
        return {};
    }
    // Clamp the request: without this, i would run past n - 1, producing an
    // invalid distribution range (lower > upper) and out-of-bounds swaps.
    const int count = (k < n) ? k : static_cast<int>(n);

    std::vector<int> numbers(n);
    std::iota(numbers.begin(), numbers.end(), 0);

    std::random_device rd;
    std::mt19937 gen(rd());
    const int last = static_cast<int>(n - 1);
    for (int i = 0; i < count; ++i) {
        // Swap position i with a random position in [i, last].
        std::uniform_int_distribution<> dist(i, last);
        std::swap(numbers[i], numbers[dist(gen)]);
    }
    numbers.resize(count);
    return numbers;
}

template <typename Gen>
vsag::Attribute*
CreateAttribute(std::string term_name,
Expand Down Expand Up @@ -316,14 +331,19 @@ generate_attributes(uint64_t count, uint32_t max_term_count, uint32_t max_value_
for (uint64_t i = 0; i < term_count; ++i) {
std::string term_name = fmt::format("term_{}", i);
auto term_type = static_cast<vsag::AttrValueType>(type_dist(gen));
if (term_type == vsag::AttrValueType::UINT64) {
term_type = vsag::AttrValueType::INT64;
}
terms[i] = {term_name, term_type};
}

for (uint32_t i = 0; i < count; ++i) {
auto cur_term_count = term_count_dist(gen) % term_count;
auto cur_term_count = RandomValue(1, term_count);
results[i].attrs_.reserve(cur_term_count);
auto idxes = select_k_numbers(term_count, cur_term_count);

for (uint32_t j = 0; j < cur_term_count; ++j) {
auto term_id = term_count_dist(gen) % term_count;
auto term_id = idxes[j];
auto& term_name = terms[term_id].first;
auto& term_type = terms[term_id].second;
auto attr = CreateAttribute(term_name, term_type, value_count_dist(gen), 10, gen);
Expand Down
10 changes: 9 additions & 1 deletion tests/test_hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,8 @@ TestHGraphBuildWithAttr(const fixtures::HGraphTestIndexPtr& test_index,
const fixtures::HGraphResourcePtr& resource) {
using namespace fixtures;
auto origin_size = vsag::Options::Instance().block_size_limit();
auto search_param = fmt::format(fixtures::search_param_tmp, 200, false);

auto size = GENERATE(1024 * 1024 * 2);

for (auto metric_type : resource->metric_types) {
Expand Down Expand Up @@ -639,8 +641,14 @@ TestHGraphBuildWithAttr(const fixtures::HGraphTestIndexPtr& test_index,
auto dataset = HGraphTestIndex::pool.GetDatasetAndCreate(
dim, resource->base_count, metric_type);

if (not index->CheckFeature(vsag::SUPPORT_BUILD)) {
continue;
}
auto build_result = index->Build(dataset->base_);
REQUIRE(build_result.has_value());

// Execute attribute-aware build test
TestIndex::TestBuildWithAttr(index, dataset);
// TestIndex::TestWithAttr(index, dataset, search_param);

// Restore original block size limit
vsag::Options::Instance().set_block_size_limit(origin_size);
Expand Down
111 changes: 106 additions & 5 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1690,13 +1690,114 @@ TestIndex::TestRemoveIndex(const TestIndex::IndexPtr& index,
REQUIRE(index->GetNumElements() == dataset->base_->GetNumElements());
}
}

/**
 * @brief Build an attribute-filter expression for one attribute.
 *
 * With exactly one value the result is an equality test, `name = value`;
 * string values are wrapped in double quotes. With any other number of
 * values (including none) the result is
 * `multi_in(name, "v1|v2|...", "|")`, where the values are joined by '|'.
 *
 * @tparam T Element type: std::string, or a type accepted by std::to_string.
 * @param name   Attribute name placed in the expression.
 * @param values Values to match against.
 * @return The filter expression text.
 */
template <class T>
std::string
create_attr_string(const std::string& name, const std::vector<T>& values) {
    constexpr bool kIsString = std::is_same_v<T, std::string>;

    // Single value: plain equality.
    if (values.size() == 1) {
        std::string expr = name + " = ";
        if constexpr (kIsString) {
            expr += "\"" + values.front() + "\"";
        } else {
            expr += std::to_string(values.front());
        }
        return expr;
    }

    // Zero or several values: '|'-separated list for multi_in.
    std::string joined;
    bool is_first = true;
    for (const auto& item : values) {
        if (!is_first) {
            joined += "|";
        }
        is_first = false;
        if constexpr (kIsString) {
            joined += item;
        } else {
            joined += std::to_string(item);
        }
    }
    return "multi_in(" + name + ", \"" + joined + "\", \"|\")";
}

/**
 * @brief Render an attribute as a filter expression, treating it as
 *        AttributeValue<T>.
 *
 * @tparam T Element type the attribute is expected to hold.
 * @param attr Attribute to convert.
 * @return The text from create_attr_string, or an empty string when `attr`
 *         is not actually an AttributeValue<T> (the cast is checked rather
 *         than dereferenced blindly).
 */
template <class T>
static std::string
typed_attr_to_string(const vsag::Attribute& attr) {
    const auto* typed = dynamic_cast<const vsag::AttributeValue<T>*>(&attr);
    if (typed == nullptr) {
        return "";
    }
    return create_attr_string<T>(attr.name_, typed->GetValue());
}

/**
 * @brief Convert an attribute into the filter-expression text used by the
 *        attribute search tests.
 *
 * Dispatches on the attribute's reported value type and formats its values
 * with create_attr_string.
 *
 * @param attr Attribute to convert.
 * @return The filter expression, or an empty string when the value type is
 *         not one of the handled types or the attribute's concrete type
 *         does not match its reported value type.
 */
std::string
trans_attr_to_string(const vsag::Attribute& attr) {
    using vsag::AttrValueType;
    switch (attr.GetValueType()) {
        case AttrValueType::STRING:
            return typed_attr_to_string<std::string>(attr);
        case AttrValueType::UINT8:
            return typed_attr_to_string<uint8_t>(attr);
        case AttrValueType::UINT16:
            return typed_attr_to_string<uint16_t>(attr);
        case AttrValueType::UINT32:
            return typed_attr_to_string<uint32_t>(attr);
        case AttrValueType::UINT64:
            return typed_attr_to_string<uint64_t>(attr);
        case AttrValueType::INT8:
            return typed_attr_to_string<int8_t>(attr);
        case AttrValueType::INT16:
            return typed_attr_to_string<int16_t>(attr);
        case AttrValueType::INT32:
            return typed_attr_to_string<int32_t>(attr);
        case AttrValueType::INT64:
            return typed_attr_to_string<int64_t>(attr);
        default:
            // Unhandled value type: same empty result as before.
            return "";
    }
}

void
TestIndex::TestBuildWithAttr(const IndexPtr& index, const TestDatasetPtr& dataset) {
if (not index->CheckFeature(vsag::SUPPORT_BUILD)) {
return;
TestIndex::TestWithAttr(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param) {
auto attrsets = dataset->base_->GetAttributeSets();
auto* query_vec = dataset->base_->GetFloat32Vectors();
auto base_count = dataset->base_->GetNumElements();
auto dim = dataset->base_->GetDim();
std::vector<vsag::SearchRequest> reqs(base_count);
for (int i = 0; i < base_count; ++i) {
auto query = vsag::Dataset::Make();
query->Float32Vectors(query_vec + i * dim)->Dim(dim)->Owner(false)->NumElements(1);
auto attrset = attrsets[i].attrs_;
int j = random() % attrset.size();
auto attr = attrset[j];
int j2 = random() % attrset.size();
INFO(std::to_string(i));
vsag::SearchRequest& req = reqs[i];
req.topk_ = 10;
req.filter_ = nullptr;
req.params_str_ = search_param;
req.enable_attribute_filter_ = true;
req.query_ = query;
req.attribute_filter_str_ = "(" + trans_attr_to_string(*attrset[j2]) + ") AND (" +
trans_attr_to_string(*attr) + ")";
}

for (int i = 0; i < base_count; ++i) {
auto the_id = dataset->base_->GetIds()[i];
auto result = index->SearchWithRequest(reqs[i]);
REQUIRE(result.has_value());
auto ids = result.value()->GetIds();
auto result_count = result.value()->GetNumElements();
std::unordered_set<int64_t> sets(ids, ids + result_count);
REQUIRE(sets.find(the_id) != sets.end());
}
auto build_result = index->Build(dataset->base_);
REQUIRE(build_result.has_value());
}

} // namespace fixtures
4 changes: 3 additions & 1 deletion tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ class TestIndex {
const std::string& search_param);

static void
TestBuildWithAttr(const IndexPtr& index, const TestDatasetPtr& dataset);
TestWithAttr(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param);

constexpr static float RECALL_THRESHOLD = 0.95;
};
Expand Down
32 changes: 26 additions & 6 deletions tests/test_ivf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,8 @@ TEST_CASE("[Daily] IVF Build With Large K", "[ft][ivf][daily]") {
}

static void
TestIVFBuildWithAttr(const fixtures::IVFTestIndexPtr& test_index,
const fixtures::IVFResourcePtr& resource) {
TestIVFWithAttr(const fixtures::IVFTestIndexPtr& test_index,
const fixtures::IVFResourcePtr& resource) {
using namespace fixtures;
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down Expand Up @@ -616,10 +616,30 @@ TestIVFBuildWithAttr(const fixtures::IVFTestIndexPtr& test_index,
false,
1,
use_attribute_filter);
auto index = IVFTestIndex::TestFactory(IVFTestIndex::name, param, true);
auto index1 = IVFTestIndex::TestFactory(IVFTestIndex::name, param, true);
auto dataset = IVFTestIndex::pool.GetDatasetAndCreate(
dim, resource->base_count, metric_type);
IVFTestIndex::TestBuildWithAttr(index, dataset);
if (not index1->CheckFeature(vsag::SUPPORT_BUILD)) {
continue;
}
auto build_result = index1->Build(dataset->base_);
REQUIRE(build_result.has_value());
IVFTestIndex::TestWithAttr(index1, dataset, search_param);

auto dir = fixtures::TempDir("serialize");
auto path = dir.GenerateRandomFile();
std::ofstream outfile(path, std::ios::out | std::ios::binary);
auto serialize_index = index1->Serialize(outfile);
REQUIRE(serialize_index.has_value());
outfile.close();

auto index = TestIndex::TestFactory(IVFTestIndex::name, param, true);
std::ifstream infile(path, std::ios::in | std::ios::binary);
auto deserialize_index = index->Deserialize(infile);
REQUIRE(deserialize_index.has_value());
infile.close();
IVFTestIndex::TestWithAttr(index, dataset, search_param);

vsag::Options::Instance().set_block_size_limit(origin_size);
}
}
Expand All @@ -630,13 +650,13 @@ TestIVFBuildWithAttr(const fixtures::IVFTestIndexPtr& test_index,
TEST_CASE("[PR] IVF Build With Attribute", "[ft][ivf][pr]") {
auto test_index = std::make_shared<fixtures::IVFTestIndex>();
auto resource = test_index->GetResource(true);
TestIVFBuildWithAttr(test_index, resource);
TestIVFWithAttr(test_index, resource);
}

TEST_CASE("[Daily] IVF Build With Attribute", "[ft][ivf][daily]") {
auto test_index = std::make_shared<fixtures::IVFTestIndex>();
auto resource = test_index->GetResource(false);
TestIVFBuildWithAttr(test_index, resource);
TestIVFWithAttr(test_index, resource);
}

static void
Expand Down
Loading
0