From e0a7e719286fabb64b0d2ac49a8cc5f2afa99425 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 24 Jun 2025 13:59:34 +0800 Subject: [PATCH] add simple test for search request Signed-off-by: LHT129 --- src/simd/avx512.cpp | 2 +- tests/fixtures/fixtures.cpp | 24 +++++++- tests/test_hgraph.cpp | 10 +++++++- tests/test_index.cpp | 111 ++++++++++++++++++++++++++++++++++-- tests/test_index.h | 4 +- tests/test_ivf.cpp | 32 +++++++++-- 6 files changed, 167 insertions(+), 16 deletions(-) diff --git a/src/simd/avx512.cpp b/src/simd/avx512.cpp index 637870ae4..b94947f21 100644 --- a/src/simd/avx512.cpp +++ b/src/simd/avx512.cpp @@ -261,7 +261,7 @@ FP32ComputeL2SqrBatch4(const float* RESTRICT query, float& result2, float& result3, float& result4) { -#if defined(ENABLE_AVX2) +#if defined(ENABLE_AVX512) if (dim < 16) { return avx2::FP32ComputeL2SqrBatch4( query, dim, codes1, codes2, codes3, codes4, result1, result2, result3, result4); diff --git a/tests/fixtures/fixtures.cpp b/tests/fixtures/fixtures.cpp index 2b2aba7f9..70ddbe7bd 100644 --- a/tests/fixtures/fixtures.cpp +++ b/tests/fixtures/fixtures.cpp @@ -235,6 +235,21 @@ FillStringValues(vsag::AttributeValue* attr, } } +static std::vector +select_k_numbers(int64_t n, int k) { // Returns k distinct values from [0, n) via a partial Fisher-Yates shuffle; requires 0 <= k <= n. + std::vector numbers(n); + std::iota(numbers.begin(), numbers.end(), 0); + + std::random_device rd; // NOTE(review): seeded independently of the caller's generator, so the selection is not reproducible across runs. + std::mt19937 gen(rd()); + for (int i = 0; i < k; ++i) { + std::uniform_int_distribution<> dist(i, static_cast(n - 1)); + std::swap(numbers[i], numbers[dist(gen)]); + } + numbers.resize(k); + return numbers; +} + template vsag::Attribute* CreateAttribute(std::string term_name, @@ -316,14 +331,19 @@ generate_attributes(uint64_t count, uint32_t max_term_count, uint32_t max_value_ for (uint64_t i = 0; i < term_count; ++i) 
{ std::string term_name = fmt::format("term_{}", i); auto term_type = static_cast(type_dist(gen)); + if (term_type == vsag::AttrValueType::UINT64) { // NOTE(review): UINT64 terms are generated as INT64 instead, presumably because the attribute-filter path does not support UINT64 yet; confirm. + term_type = vsag::AttrValueType::INT64; + } terms[i] = {term_name, term_type}; } for 
(uint32_t i = 0; i < count; ++i) { - auto cur_term_count = term_count_dist(gen) % term_count; + auto cur_term_count = RandomValue(1, term_count); results[i].attrs_.reserve(cur_term_count); + auto idxes = select_k_numbers(term_count, cur_term_count); // distinct term ids, so no term is attached to the same row twice + for (uint32_t j = 0; j < cur_term_count; ++j) { - auto term_id = term_count_dist(gen) % term_count; + auto term_id = idxes[j]; auto& term_name = terms[term_id].first; auto& term_type = terms[term_id].second; auto attr = CreateAttribute(term_name, term_type, value_count_dist(gen), 10, gen); diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index a597a472c..1d1636358 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -623,6 +623,8 @@ TestHGraphBuildWithAttr(const fixtures::HGraphTestIndexPtr& test_index, const fixtures::HGraphResourcePtr& resource) { using namespace fixtures; auto origin_size = vsag::Options::Instance().block_size_limit(); + auto search_param = fmt::format(fixtures::search_param_tmp, 200, false); // NOTE(review): unused while the TestWithAttr call below stays commented out; HGraph is only checked for a successful Build(). + auto size = GENERATE(1024 * 1024 * 2); for (auto metric_type : resource->metric_types) { @@ -659,8 +661,15 @@ TestHGraphBuildWithAttr(const fixtures::HGraphTestIndexPtr& test_index, auto dataset = HGraphTestIndex::pool.GetDatasetAndCreate( dim, resource->base_count, metric_type); + if (not index->CheckFeature(vsag::SUPPORT_BUILD)) { + vsag::Options::Instance().set_block_size_limit(origin_size); // restore on this early-exit path too, as the end of the iteration does + continue; + } + auto build_result = index->Build(dataset->base_); + REQUIRE(build_result.has_value()); + // Execute attribute-aware build test - TestIndex::TestBuildWithAttr(index, dataset); + // TestIndex::TestWithAttr(index, dataset, search_param); // Restore original block size limit vsag::Options::Instance().set_block_size_limit(origin_size); diff --git 
a/tests/test_index.cpp b/tests/test_index.cpp index 28ea61e3c..3d22c3ed6 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -1690,13 +1690,114 @@ TestIndex::TestRemoveIndex(const TestIndex::IndexPtr& index, REQUIRE(index->GetNumElements() == dataset->base_->GetNumElements()); } } +
+template +std::string +create_attr_string(const std::string& name, const std::vector& values) { // Formats one attribute as a filter term: `name = value` for a single value, otherwise multi_in(name, "v1|v2|...", "|"). NOTE(review): string values are not escaped, so a value containing a double quote or '|' yields a wrong filter. + if (values.size() == 1) { + std::stringstream ss; + if constexpr (std::is_same_v) { + ss << name << " = \"" << values[0] << "\""; + } else { + ss << name << " = " << std::to_string(values[0]); // std::to_string keeps int8/uint8 values numeric instead of streaming them as characters + } + return ss.str(); + } + std::ostringstream oss; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) { + oss << "|"; + } + if constexpr (std::is_same_v) { + oss << values[i]; + } else { + oss << std::to_string(values[i]); + } + } + return "multi_in(" + name + ", \"" + oss.str() + "\", \"|\")"; +} +// Converts one vsag::Attribute into a filter term via create_attr_string; returns an empty string for value types not handled below. +std::string +trans_attr_to_string(const vsag::Attribute& attr) { + using namespace vsag; + auto name = attr.name_; + auto type = attr.GetValueType(); + if (type == AttrValueType::STRING) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } else if (type == AttrValueType::UINT8) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } else if (type == AttrValueType::UINT16) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } else if (type == AttrValueType::UINT32) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } else if (type == AttrValueType::UINT64) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return 
create_attr_string(name, values); + } else if (type == AttrValueType::INT8) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } else if (type == AttrValueType::INT16) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } else if (type == AttrValueType::INT32) { + const auto temp = dynamic_cast*>(&attr); +
const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } else if (type == AttrValueType::INT64) { + const auto temp = dynamic_cast*>(&attr); + const auto& values = temp->GetValue(); + return create_attr_string(name, values); + } + return ""; +} + void -TestIndex::TestBuildWithAttr(const IndexPtr& index, const TestDatasetPtr& dataset) { - if (not index->CheckFeature(vsag::SUPPORT_BUILD)) { - return; +TestIndex::TestWithAttr(const IndexPtr& index, + const TestDatasetPtr& dataset, + const std::string& search_param) { // For each base vector: search with that vector plus an AND of two randomly chosen (possibly identical) attributes of its own, then require its own id in the top-10 results. + auto attrsets = dataset->base_->GetAttributeSets(); + auto* query_vec = dataset->base_->GetFloat32Vectors(); + auto base_count = dataset->base_->GetNumElements(); + auto dim = dataset->base_->GetDim(); + std::vector reqs(base_count); + for (int i = 0; i < base_count; ++i) { + auto query = vsag::Dataset::Make(); + query->Float32Vectors(query_vec + i * dim)->Dim(dim)->Owner(false)->NumElements(1); + const auto& attrset = attrsets[i].attrs_; + REQUIRE_FALSE(attrset.empty()); // the modulo below would otherwise divide by zero + auto j = random() % attrset.size(); + auto j2 = random() % attrset.size(); + vsag::SearchRequest& req = reqs[i]; + req.topk_ = 10; + req.filter_ = nullptr; + req.params_str_ = search_param; + req.enable_attribute_filter_ = true; + req.query_ = query; + req.attribute_filter_str_ = "(" + trans_attr_to_string(*attrset[j2]) + ") AND (" + + trans_attr_to_string(*attrset[j]) + ")"; + } + + for (int i = 0; i < base_count; ++i) { + INFO("query " << i << ", filter: " << reqs[i].attribute_filter_str_); + auto the_id = dataset->base_->GetIds()[i]; + auto result = index->SearchWithRequest(reqs[i]); + REQUIRE(result.has_value()); + auto ids = result.value()->GetIds(); + auto result_count = result.value()->GetNumElements(); + std::unordered_set sets(ids, ids + 
result_count); + REQUIRE(sets.find(the_id) != sets.end()); } - auto build_result = index->Build(dataset->base_); - REQUIRE(build_result.has_value()); } void TestIndex::TestGetRawVectorByIds(const IndexPtr& index, diff --git a/tests/test_index.h b/tests/test_index.h index b04d377e6..4e8f389c7 
100644 --- a/tests/test_index.h +++ b/tests/test_index.h @@ -256,7 +256,9 @@ class TestIndex { const std::string& search_param); static void - TestBuildWithAttr(const IndexPtr& index, const TestDatasetPtr& dataset); + TestWithAttr(const IndexPtr& index, + const TestDatasetPtr& dataset, + const std::string& search_param); constexpr static float RECALL_THRESHOLD = 0.95; }; diff --git a/tests/test_ivf.cpp b/tests/test_ivf.cpp index fbe0a7d20..3be02ff4c 100644 --- a/tests/test_ivf.cpp +++ b/tests/test_ivf.cpp @@ -594,8 +594,8 @@ TEST_CASE("[Daily] IVF Build With Large K", "[ft][ivf][daily]") { } static void -TestIVFBuildWithAttr(const fixtures::IVFTestIndexPtr& test_index, - const fixtures::IVFResourcePtr& resource) { +TestIVFWithAttr(const fixtures::IVFTestIndexPtr& test_index, + const fixtures::IVFResourcePtr& resource) { using namespace fixtures; auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); @@ -628,10 +628,31 @@ TestIVFBuildWithAttr(const fixtures::IVFTestIndexPtr& test_index, false, 1, use_attribute_filter); - auto index = IVFTestIndex::TestFactory(IVFTestIndex::name, param, true); + auto index1 = IVFTestIndex::TestFactory(IVFTestIndex::name, param, true); auto dataset = IVFTestIndex::pool.GetDatasetAndCreate( dim, resource->base_count, metric_type); - IVFTestIndex::TestBuildWithAttr(index, dataset); + if (not index1->CheckFeature(vsag::SUPPORT_BUILD)) { + vsag::Options::Instance().set_block_size_limit(origin_size); // restore on this early-exit path too, as the end of the iteration does + continue; + } + auto build_result = index1->Build(dataset->base_); + REQUIRE(build_result.has_value()); + IVFTestIndex::TestWithAttr(index1, dataset, search_param); + // Round-trip the index through Serialize/Deserialize and rerun the same attribute-filtered searches on the restored copy. 
+ auto dir = fixtures::TempDir("serialize"); + auto path = dir.GenerateRandomFile(); + std::ofstream outfile(path, std::ios::out | std::ios::binary); + auto serialize_index = index1->Serialize(outfile); + REQUIRE(serialize_index.has_value()); + outfile.close(); + + auto index = IVFTestIndex::TestFactory(IVFTestIndex::name, param, true); + std::ifstream infile(path, std::ios::in | 
std::ios::binary); + auto deserialize_index = index->Deserialize(infile); + REQUIRE(deserialize_index.has_value()); + infile.close(); + IVFTestIndex::TestWithAttr(index, dataset, search_param); + vsag::Options::Instance().set_block_size_limit(origin_size); } } @@ -642,13 +663,13 @@ TestIVFBuildWithAttr(const fixtures::IVFTestIndexPtr& test_index, TEST_CASE("[PR] IVF Build With Attribute", "[ft][ivf][pr]") { auto test_index = std::make_shared(); auto resource = test_index->GetResource(true); - TestIVFBuildWithAttr(test_index, resource); + TestIVFWithAttr(test_index, resource); } TEST_CASE("[Daily] IVF Build With Attribute", "[ft][ivf][daily]") { auto test_index = std::make_shared(); auto resource = test_index->GetResource(false); - TestIVFBuildWithAttr(test_index, resource); + TestIVFWithAttr(test_index, resource); } static void