Introduce mrq by ShawnShawnYou · Pull Request #872 · antgroup/vsag · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Introduce mrq #872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/impl/transform/pca_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ PCATransformer::Transform(const float* input_vec, float* output_vec) const {
this->CentralizeData(input_vec, centralized_vec.data());

// output_vec[i] = sum_j(input_vec[j] * pca_matrix_[j, i])
// e.g., original_dim == 3, target_dim == 2
// e.g., input_dim == 3, output_dim == 2
// [1, 0, 0,] * [1,] = [1,]
// [0, 0, 1 ] [2,] = [3 ]
// [3 ]
Expand Down Expand Up @@ -193,7 +193,7 @@ PCATransformer::SetMeanForTest(const float* input_mean) {
}

void
PCATransformer::SetPCAMatrixForText(const float* input_pca_matrix) {
PCATransformer::SetPCAMatrixForTest(const float* input_pca_matrix) {
for (uint64_t i = 0; i < pca_matrix_.size(); i++) {
pca_matrix_[i] = input_pca_matrix[i];
}
Expand Down
3 changes: 1 addition & 2 deletions src/impl/transform/pca_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

namespace vsag {

// aka PCA
class PCATransformer : public VectorTransformer {
public:
// interface
Expand Down Expand Up @@ -57,7 +56,7 @@ class PCATransformer : public VectorTransformer {
SetMeanForTest(const float* input_mean);

void
SetPCAMatrixForText(const float* input_pca_matrix);
SetPCAMatrixForTest(const float* input_pca_matrix);

void
ComputeColumnMean(const float* data, uint64_t count);
Expand Down
2 changes: 1 addition & 1 deletion src/impl/transform/pca_transformer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ TestTransform() {
1.0f}; // eigen_vec[-2]
PCATransformer pca(allocator.get(), original_dim, target_dim);
pca.SetMeanForTest(mean.data());
pca.SetPCAMatrixForText(pca_matrix.data());
pca.SetPCAMatrixForTest(pca_matrix.data());

std::vector<float> input = {4.0f, 6.0f, 8.0f}; // centralized: [1, 2, 3]
std::vector<float> output(target_dim, 0);
Expand Down
78 changes: 76 additions & 2 deletions src/quantization/quantizer_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <fstream>

#include "fixtures.h"
#include "fp32_quantizer.h"
#include "iostream"
#include "quantizer.h"
#include "simd/normalize.h"
Expand Down Expand Up @@ -46,20 +47,27 @@ TestEncodeDecodeRaBitQ(Quantizer<T>& quantizer,

// Test EncodeOne & DecodeOne
float count_same_sign_1 = 0;
bool decode_result;
std::vector<uint8_t> codes1(quantizer.GetCodeSize() * count);
for (uint64_t i = 0; i < count; ++i) {
uint8_t* codes = codes1.data() + i * quantizer.GetCodeSize();
quantizer.EncodeOne(vecs.data() + i * dim, codes);

std::vector<float> out_vec(dim);
quantizer.DecodeOne(codes, out_vec.data());
decode_result = quantizer.DecodeOne(codes, out_vec.data());
if (not decode_result) {
continue;
}
for (uint64_t d = 0; d < dim; ++d) {
if (vecs[i * dim + d] * out_vec[d] >= 0) {
count_same_sign_1++;
}
}
}
REQUIRE(count_same_sign_1 / (count * dim) > same_sign_rate);

if (decode_result) {
REQUIRE(count_same_sign_1 / (count * dim) > same_sign_rate);
}

// Test EncodeBatch & DecodeBatch
float count_same_sign_2 = 0;
Expand All @@ -69,6 +77,10 @@ TestEncodeDecodeRaBitQ(Quantizer<T>& quantizer,
REQUIRE(codes1[c] == codes2[c]);
}

if (not decode_result) {
return;
}

std::vector<float> out_vec(dim * count);
quantizer.DecodeBatch(codes2.data(), out_vec.data(), count);
for (int64_t i = 0; i < dim * count; ++i) {
Expand Down Expand Up @@ -223,6 +235,68 @@ TestComputeCodesSame(Quantizer<T>& quantizer,
}
}

template <typename T>
std::vector<std::vector<float>>
ComputeAllDists(Quantizer<T>& quantizer, std::vector<float> data, uint32_t count, uint32_t dim) {
std::vector<uint8_t> codes(quantizer.GetCodeSize() * count);
std::vector<std::vector<float>> dists(count, std::vector<float>(count, 0));

// 1. encode
quantizer.EncodeBatch(data.data(), codes.data(), count);

// 2. compute dist
for (int i = 0; i < count; i++) {
std::shared_ptr<Computer<T>> computer;
computer = quantizer.FactoryComputer();
computer->SetQuery(data.data() + i * dim);

for (int j = 0; j < count; j++) {
auto code = codes.data() + j * quantizer.GetCodeSize();
if (i == j) {
continue;
}
dists[i][j] = quantizer.ComputeDist(*computer, code);
}
}
return dists;
}

template <typename T, MetricType metric>
void
TestInversePair(Quantizer<T>& quantizer, size_t dim, uint32_t count, Allocator* allocator) {
    // Measure the ranking quality of `quantizer` against exact FP32 distances:
    // for every query i and every base pair (j, k), count how often the
    // quantized distances invert the FP32 distance order. The test requires
    // the inversion rate to stay below 0.5 (better than random ordering).
    count = std::min<uint32_t>(count, 100);  // cap the O(count^3) comparison work
    auto data = fixtures::generate_vectors(count, dim, false);
    FP32Quantizer<metric> fp32_quantizer(dim, allocator);
    fp32_quantizer.ReTrain(data.data(), count);
    quantizer.ReTrain(data.data(), count);

    auto dist_fp32 = ComputeAllDists<FP32Quantizer<metric>>(fp32_quantizer, data, count, dim);
    auto dist_quan = ComputeAllDists<T>(quantizer, data, count, dim);

    uint32_t count_diff = 0;
    uint32_t count_compare = 0;
    for (uint32_t i = 0; i < count; i++) {
        for (uint32_t j = 0; j < count; j++) {
            for (uint32_t k = j + 1; k < count; k++) {
                count_compare++;
                // an inversion: fp32 and quantized distances disagree on the
                // relative order of base vectors j and k w.r.t. query i
                if (dist_fp32[i][j] < dist_fp32[i][k]) {
                    if (dist_quan[i][j] > dist_quan[i][k]) {
                        count_diff++;
                    }
                } else {
                    if (dist_quan[i][j] < dist_quan[i][k]) {
                        count_diff++;
                    }
                }
            }
        }
    }

    logger::debug(fmt::format("count_diff: {}, count_compare: {}", count_diff, count_compare));
    // guard: with count < 2 there are no pairs, and dividing by 0 would be UB
    if (count_compare > 0) {
        REQUIRE(1.0 * count_diff / count_compare < 0.5);
    }
}

template <typename T, MetricType metric>
void
TestComputer(Quantizer<T>& quant,
Expand Down
71 changes: 59 additions & 12 deletions src/quantization/rabitq_quantization/rabitq_quantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@
#include "quantization/quantizer.h"
#include "quantization/scalar_quantization/sq4_uniform_quantizer.h"
#include "rabitq_quantizer_parameter.h"
#include "simd/fp32_simd.h"
#include "simd/normalize.h"
#include "simd/rabitq_simd.h"
#include "typing.h"
#include "utils/util_functions.h"

namespace vsag {

/** Implement of RaBitQ Quantization
/** Implement of RaBitQ Quantization, Integrate MRQ (Minimized Residual Quantization)
*
* Supports bit-level quantization
* RaBitQ: Supports bit-level quantization
* MRQ: Support use residual part of PCA to increase precision
*
* Reference:
* Jianyang Gao and Cheng Long. 2024. RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search. Proc. ACM Manag. Data 2, 3, Article 167 (June 2024), 27 pages. https://doi.org/10.1145/3654970
* [1] Jianyang Gao and Cheng Long. 2024. RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search. Proc. ACM Manag. Data 2, 3, Article 167 (June 2024), 27 pages. https://doi.org/10.1145/3654970
* [2] Mingyu Yang, Wentao Li, Wei Wang. Fast High-dimensional Approximate Nearest Neighbor Search with Efficient Index Time and Space
*/
template <MetricType metric = MetricType::METRIC_TYPE_L2SQR>
class RaBitQuantizer : public Quantizer<RaBitQuantizer<metric>> {
Expand All @@ -51,6 +54,7 @@ class RaBitQuantizer : public Quantizer<RaBitQuantizer<metric>> {
uint64_t pca_dim,
uint64_t num_bits_per_dim_query,
bool use_fht,
bool use_mrq,
Allocator* allocator);

explicit RaBitQuantizer(const RaBitQuantizerParamPtr& param,
Expand Down Expand Up @@ -164,9 +168,10 @@ class RaBitQuantizer : public Quantizer<RaBitQuantizer<metric>> {
std::shared_ptr<PCATransformer> pca_;
std::uint64_t original_dim_{0};
std::uint64_t pca_dim_{0};
bool use_mrq_{false};

/***
* query layout: sq-code(required) + lower_bound(sq4) + delta(sq4) + sum(sq4) + norm(required)
* query layout: sq-code(required) + lower_bound(sq4) + delta(sq4) + sum(sq4) + norm(required) + mrq_norm(required)
*/
uint64_t aligned_dim_{0};
uint64_t num_bits_per_dim_query_{32};
Expand All @@ -175,27 +180,38 @@ class RaBitQuantizer : public Quantizer<RaBitQuantizer<metric>> {
uint64_t query_offset_delta_{0};
uint64_t query_offset_sum_{0};
uint64_t query_offset_norm_{0};
uint64_t query_offset_mrq_norm_{0};

/***
* code layout: bq-code(required) + norm(required) + error(required) + sum(sq4)
* code layout: bq-code(required) + norm(required) + error(required) + sum(sq4) + mrq_norm(required)
*/
uint64_t offset_code_{0};
uint64_t offset_norm_{0};
uint64_t offset_error_{0};
uint64_t offset_sum_{0};
uint64_t offset_mrq_norm_{0};
};

template <MetricType metric>
RaBitQuantizer<metric>::RaBitQuantizer(
int dim, uint64_t pca_dim, uint64_t num_bits_per_dim_query, bool use_fht, Allocator* allocator)
RaBitQuantizer<metric>::RaBitQuantizer(int dim,
uint64_t pca_dim,
uint64_t num_bits_per_dim_query,
bool use_fht,
bool use_mrq,
Allocator* allocator)
: Quantizer<RaBitQuantizer<metric>>(dim, allocator) {
static_assert(metric == MetricType::METRIC_TYPE_L2SQR, "Unsupported metric type");

// dim
use_mrq_ = use_mrq;
pca_dim_ = pca_dim;
original_dim_ = dim;
if (0 < pca_dim_ and pca_dim_ < dim) {
pca_.reset(new PCATransformer(allocator, dim, pca_dim_));
if (use_mrq_) {
pca_.reset(new PCATransformer(allocator, dim, dim));
} else {
pca_.reset(new PCATransformer(allocator, dim, pca_dim_));
}
this->dim_ = pca_dim_;
} else {
pca_dim_ = dim;
Expand Down Expand Up @@ -262,6 +278,15 @@ RaBitQuantizer<metric>::RaBitQuantizer(

query_offset_norm_ = this->query_code_size_;
this->query_code_size_ += ((sizeof(norm_type) + align_size - 1) / align_size) * align_size;

// MRQ residual term
if (pca_dim_ != original_dim_ and use_mrq_) {
offset_mrq_norm_ = this->code_size_;
this->code_size_ += ((sizeof(norm_type) + align_size - 1) / align_size) * align_size;

query_offset_mrq_norm_ = this->query_code_size_;
this->query_code_size_ += ((sizeof(norm_type) + align_size - 1) / align_size) * align_size;
}
}

template <MetricType metric>
Expand All @@ -271,6 +296,7 @@ RaBitQuantizer<metric>::RaBitQuantizer(const RaBitQuantizerParamPtr& param,
param->pca_dim_,
param->num_bits_per_dim_query_,
param->use_fht_,
false,
common_param.allocator_.get()){};

template <MetricType metric>
Expand Down Expand Up @@ -300,7 +326,7 @@ RaBitQuantizer<metric>::TrainImpl(const DataType* data, uint64_t count) {
centroid_[d] = 0;
}
for (uint64_t i = 0; i < count; ++i) {
Vector<DataType> pca_data(this->dim_, 0, this->allocator_);
Vector<DataType> pca_data(this->original_dim_, 0, this->allocator_);
if (pca_dim_ != this->original_dim_) {
pca_->Transform(data + i * original_dim_, pca_data.data());
} else {
Expand Down Expand Up @@ -330,13 +356,20 @@ template <MetricType metric>
bool
RaBitQuantizer<metric>::EncodeOneImpl(const DataType* data, uint8_t* codes) const {
// 0. init
Vector<DataType> pca_data(this->dim_, 0, this->allocator_);
Vector<DataType> pca_data(this->original_dim_, 0, this->allocator_);
Vector<DataType> transformed_data(this->dim_, 0, this->allocator_);
Vector<DataType> normed_data(this->dim_, 0, this->allocator_);
memset(codes, 0, this->code_size_);

// 1. pca
if (pca_dim_ != this->original_dim_) {
pca_->Transform(data, pca_data.data());
if (use_mrq_) {
norm_type mrq_norm_sqr = FP32ComputeIP(pca_data.data() + this->dim_,
pca_data.data() + this->dim_,
this->original_dim_ - this->dim_);
*(norm_type*)(codes + offset_mrq_norm_) = mrq_norm_sqr;
}
} else {
pca_data.assign(data, data + original_dim_);
}
Expand All @@ -350,7 +383,6 @@ RaBitQuantizer<metric>::EncodeOneImpl(const DataType* data, uint8_t* codes) cons

// 4. encode with BQ
sum_type sum = 0;
memset(codes, 0, this->code_size_);
for (uint64_t d = 0; d < this->dim_; ++d) {
if (normed_data[d] >= 0.0f) {
sum += 1;
Expand Down Expand Up @@ -464,6 +496,13 @@ RaBitQuantizer<metric>::ComputeQueryBaseImpl(const uint8_t* query_codes,

float result = L2_UBE(base_norm, query_norm, ip_est);

if (pca_dim_ != this->original_dim_ and use_mrq_) {
norm_type query_mrq_norm_sqr = *(norm_type*)(query_codes + query_offset_mrq_norm_);
norm_type base_mrq_norm_sqr = *(norm_type*)(base_codes + offset_mrq_norm_);

result += (query_mrq_norm_sqr + base_mrq_norm_sqr);
}

return result;
}

Expand Down Expand Up @@ -606,13 +645,21 @@ RaBitQuantizer<metric>::ProcessQueryImpl(const DataType* query,
computer.buf_ = reinterpret_cast<uint8_t*>(this->allocator_->Allocate(query_code_size_));
std::fill(computer.buf_, computer.buf_ + query_code_size_, 0);

Vector<DataType> pca_data(this->dim_, 0, this->allocator_);
// use residual term in pca, so it's this->original_dim_
Vector<DataType> pca_data(this->original_dim_, 0, this->allocator_);
Vector<DataType> transformed_data(this->dim_, 0, this->allocator_);
Vector<DataType> normed_data(this->dim_, 0, this->allocator_);

// 1. pca
if (pca_dim_ != this->original_dim_) {
pca_->Transform(query, pca_data.data());
if (use_mrq_) {
norm_type mrq_norm_sqr = FP32ComputeIP(pca_data.data() + this->dim_,
pca_data.data() + this->dim_,
this->original_dim_ - this->dim_);

*(norm_type*)(computer.buf_ + query_offset_mrq_norm_) = mrq_norm_sqr;
}
} else {
pca_data.assign(query, query + original_dim_);
}
Expand Down
Loading
0