Add more support to sparse vector. by small-turtle-1 · Pull Request #1254 · infiniflow/infinity

Add more support to sparse vector. #1254

Merged (3 commits) on May 29, 2024
4 changes: 4 additions & 0 deletions src/common/stl.cppm
@@ -252,6 +252,10 @@ export namespace std {
using std::construct_at;

using std::set;

using std::all_of;
using std::any_of;
using std::none_of;
} // namespace std

namespace infinity {
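The three exports above make the standard algorithm predicates visible through the project's std module wrapper. A minimal sketch of how they might be used (illustrative only; `indices` and `max_dim` are invented here, and Vector/i64 are assumed to be the project's aliases for std::vector/int64_t):

    // Validate that every sparse index is in [0, max_dim) with the newly exported std::all_of.
    Vector<i64> indices{0, 3, 7};
    i64 max_dim = 8;
    bool in_range = std::all_of(indices.begin(), indices.end(),
                                [max_dim](i64 idx) { return idx >= 0 && idx < max_dim; });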
121 changes: 94 additions & 27 deletions src/executor/operator/physical_import.cpp
@@ -752,41 +752,97 @@ SharedPtr<ConstantExpr> BuildConstantExprFromJson(const nlohmann::json &json_obj
}
}
}
default: {
const auto error_info = fmt::format("Unrecognized json object type: {}", json_object.type_name());
LOG_ERROR(error_info);
RecoverableError(Status::ImportFileFormatError(error_info));
return nullptr;
}
}
}

SharedPtr<ConstantExpr> BuildConstantSparseExprFromJson(const nlohmann::json &json_object, const SparseInfo *sparse_info) {
SharedPtr<ConstantExpr> res = nullptr;
switch (sparse_info->DataType()) {
case kElemBit:
case kElemInt8:
case kElemInt16:
case kElemInt32:
case kElemInt64: {
res = MakeShared<ConstantExpr>(LiteralType::kLongSparseArray);
break;
}
case kElemFloat:
case kElemDouble: {
res = MakeShared<ConstantExpr>(LiteralType::kDoubleSparseArray);
break;
}
default: {
const auto error_info = fmt::format("Unsupported sparse data type: {}", sparse_info->DataType());
RecoverableError(Status::ImportFileFormatError(error_info));
return nullptr;
}
}
if (json_object.size() == 0) {
return res;
}
switch (json_object.type()) {
case nlohmann::json::value_t::array: {
const u32 array_size = json_object.size();
switch (json_object[0].type()) {
case nlohmann::json::value_t::number_unsigned:
case nlohmann::json::value_t::number_integer: {
res->long_sparse_array_.first.resize(array_size);
for (u32 i = 0; i < array_size; ++i) {
res->long_sparse_array_.first[i] = json_object[i].get<i64>();
}
return res;
}
default: {
const auto error_info = fmt::format("Unrecognized json object type in array: {}", json_object.type_name());
RecoverableError(Status::ImportFileFormatError(error_info));
return nullptr;
}
}
}
case nlohmann::json::value_t::object: {
HashSet<i64> key_set;
for (auto iter = json_object.begin(); iter != json_object.end(); ++iter) {
i64 key = std::stoll(iter.key());
auto [_, insert_ok] = key_set.insert(key);
if (!insert_ok) {
const auto error_info = fmt::format("Duplicate key {} in sparse array!", key);
RecoverableError(Status::ImportFileFormatError(error_info));
return nullptr;
}
if (res->literal_type_ == LiteralType::kLongSparseArray) {
const auto &value_obj = iter.value();
switch(value_obj.type()) {
case nlohmann::json::value_t::number_unsigned:
case nlohmann::json::value_t::number_integer: {
res->long_sparse_array_.first.push_back(key);
res->long_sparse_array_.second.push_back(value_obj.get<i64>());
break;
}
default: {
const auto error_info = fmt::format("Unrecognized json object type in array: {}", json_object.type_name());
RecoverableError(Status::ImportFileFormatError(error_info));
return nullptr;
}
}
} else {
const auto &value_obj = iter.value();
switch(value_obj.type()) {
case nlohmann::json::value_t::number_float: {
res->double_sparse_array_.first.push_back(key);
res->double_sparse_array_.second.push_back(value_obj.get<double>());
break;
}
default: {
const auto error_info = fmt::format("Unrecognized json object type in array: {}", json_object.type_name());
RecoverableError(Status::ImportFileFormatError(error_info));
return nullptr;
}
}
}
}
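For orientation, a short sketch (not part of the diff) of the two JSONL value shapes the new parser accepts for a sparse column; the column name "vec" is invented for illustration:

    // Index-only form: a JSON array of integers fills the index list of a kLongSparseArray.
    nlohmann::json row1 = nlohmann::json::parse(R"({"vec": [0, 7, 23]})");
    // Index/value form: an object maps each index (string key) to a value. The value kind must
    // match the column's SparseInfo (integers for kLongSparseArray, floats for kDoubleSparseArray),
    // and duplicate keys are rejected with ImportFileFormatError.
    nlohmann::json row2 = nlohmann::json::parse(R"({"vec": {"0": 1.5, "7": -2.0}})");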
@@ -888,10 +944,21 @@ void PhysicalImport::JSONLRowHandler(const nlohmann::json &line_json, Vector<Col
break;
}
case kTensor:
case kTensorArray: {
// build ConstantExpr
SharedPtr<ConstantExpr> const_expr = BuildConstantExprFromJson(line_json[column_def->name_]);
if (const_expr.get() == nullptr) {
RecoverableError(Status::ImportFileFormatError("Invalid json object."));
}
column_vector.AppendByConstantExpr(const_expr.get());
break;
}
case kSparse: {
const auto *sparse_info = static_cast<SparseInfo *>(column_vector.data_type()->type_info().get());
SharedPtr<ConstantExpr> const_expr = BuildConstantSparseExprFromJson(line_json[column_def->name_], sparse_info);
if (const_expr.get() == nullptr) {
RecoverableError(Status::ImportFileFormatError("Invalid json object."));
}
column_vector.AppendByConstantExpr(const_expr.get());
break;
}
120 changes: 120 additions & 0 deletions src/function/cast/embedding_cast.cppm
@@ -31,6 +31,7 @@ import status;
import logical_type;
import internal_types;
import embedding_info;
import sparse_info;
import knn_expr;
import data_type;
import default_values;
@@ -50,6 +51,12 @@ export inline BoundCastFunc BindEmbeddingCast(const DataType &source, const Data
if (source.type() == LogicalType::kEmbedding && target.type() == LogicalType::kTensor) {
return BoundCastFunc(&ColumnVectorCast::TryCastColumnVectorToVarlenWithType<EmbeddingT, TensorT, EmbeddingTryCastToVarlen>);
}
if (source.type() == LogicalType::kEmbedding && target.type() == LogicalType::kSparse) {
auto *sparse_info = static_cast<const SparseInfo *>(target.type_info().get());
if (sparse_info->DataType() == EmbeddingDataType::kElemBit) {
return BoundCastFunc(&ColumnVectorCast::TryCastColumnVectorToVarlenWithType<EmbeddingT, SparseT, EmbeddingTryCastToVarlen>);
}
}

if (source.type() != LogicalType::kEmbedding || target.type() != LogicalType::kEmbedding) {
Status status = Status::NotSupportedTypeConversion(source.ToString(), target.ToString());
@@ -377,4 +384,117 @@ inline bool EmbeddingTryCastToVarlen::Run(const EmbeddingT &source,
return true;
}

template <typename IdxT, typename SourceType>
void EmbeddingTryCastToSparseImpl(const EmbeddingT &source,
const EmbeddingInfo *source_info,
SparseT &target,
const SparseInfo *target_info,
ColumnVector *target_vector_ptr) {
SizeT source_dim = source_info->Dimension();
auto target_max_dim = static_cast<SourceType>(target_info->Dimension());
{
HashSet<IdxT> idx_set;
const auto *source_ptr = reinterpret_cast<const SourceType *>(source.ptr);
for (SizeT i = 0; i < source_dim; ++i) {
if (source_ptr[i] >= target_max_dim || source_ptr[i] < 0) {
RecoverableError(
Status::DataTypeMismatch(fmt::format("{} with data {}", source_info->ToString(), source_ptr[i]), target_info->ToString()));
}
auto [_, inserted] = idx_set.insert(static_cast<IdxT>(source_ptr[i]));
if (!inserted) {
RecoverableError(Status::InvalidDataType());
}
}
}

target.nnz_ = source_dim;
if constexpr (std::is_same_v<IdxT, SourceType>) {
SizeT source_size = source_dim * sizeof(SourceType);
const auto [chunk_id, chunk_offset] = target_vector_ptr->buffer_->fix_heap_mgr_->AppendToHeap(source.ptr, source_size);
target.chunk_id_ = chunk_id;
target.chunk_offset_ = chunk_offset;
} else {
SizeT target_size = source_dim * sizeof(IdxT);
auto target_tmp_ptr = MakeUniqueForOverwrite<IdxT[]>(source_dim);
if (!EmbeddingTryCastToFixlen::Run(reinterpret_cast<const SourceType *>(source.ptr), target_tmp_ptr.get(), source_dim)) {
UnrecoverableError(fmt::format("Failed to cast from embedding with type {} to sparse with type {}",
DataType::TypeToString<SourceType>(),
DataType::TypeToString<IdxT>()));
}
auto [chunk_id, chunk_offset] =
target_vector_ptr->buffer_->fix_heap_mgr_->AppendToHeap(reinterpret_cast<const char *>(target_tmp_ptr.get()), target_size);
target.chunk_id_ = chunk_id;
target.chunk_offset_ = chunk_offset;
}
}
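A standalone sketch of the validation EmbeddingTryCastToSparseImpl performs before storing embedding entries as bit-sparse indices (illustrative only; it uses plain std containers rather than infinity's wrappers, and the function name is invented):

    #include <cstdint>
    #include <unordered_set>
    #include <vector>

    // Every embedding entry must lie in [0, target_dim) and appear at most once;
    // the real code raises DataTypeMismatch / InvalidDataType instead of returning false.
    bool ValidateSparseIndices(const std::vector<int32_t> &embedding, int32_t target_dim) {
        std::unordered_set<int32_t> seen;
        for (int32_t idx : embedding) {
            if (idx < 0 || idx >= target_dim) {
                return false;
            }
            if (!seen.insert(idx).second) {
                return false;
            }
        }
        return true;
    }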

template <typename IdxT>
void EmbeddingTryCastToSparse(const EmbeddingT &source,
const EmbeddingInfo *source_info,
SparseT &target,
const SparseInfo *target_info,
ColumnVector *target_vector_ptr) {
switch (source_info->Type()) {
case kElemInt8: {
EmbeddingTryCastToSparseImpl<IdxT, TinyIntT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
case kElemInt16: {
EmbeddingTryCastToSparseImpl<IdxT, SmallIntT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
case kElemInt32: {
EmbeddingTryCastToSparseImpl<IdxT, IntegerT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
case kElemInt64: {
EmbeddingTryCastToSparseImpl<IdxT, BigIntT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
default: {
UnrecoverableError(
fmt::format("Cannot cast from embedding with type {} to sparse", EmbeddingInfo::EmbeddingDataTypeToString(source_info->Type())));
}
}
}

template <>
inline bool EmbeddingTryCastToVarlen::Run(const EmbeddingT &source,
const DataType &source_type,
SparseT &target,
const DataType &target_type,
ColumnVector *target_vector_ptr) {
if (source_type.type() != LogicalType::kEmbedding) {
UnrecoverableError(fmt::format("Type here is expected as Embedding, but actually it is: {}", source_type.ToString()));
}
const auto *source_info = static_cast<EmbeddingInfo *>(source_type.type_info().get());
const auto *target_info = static_cast<SparseInfo *>(target_type.type_info().get());

if (target_info->DataType() != EmbeddingDataType::kElemBit) {
UnrecoverableError(fmt::format("No support data type: {}", EmbeddingType::EmbeddingDataType2String(target_info->IndexType())));
}
switch (target_info->IndexType()) {
case EmbeddingDataType::kElemInt8: {
EmbeddingTryCastToSparse<TinyIntT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
case EmbeddingDataType::kElemInt16: {
EmbeddingTryCastToSparse<SmallIntT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
case EmbeddingDataType::kElemInt32: {
EmbeddingTryCastToSparse<IntegerT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
case EmbeddingDataType::kElemInt64: {
EmbeddingTryCastToSparse<BigIntT>(source, source_info, target, target_info, target_vector_ptr);
break;
}
default: {
UnrecoverableError(fmt::format("No support data type: {}", EmbeddingType::EmbeddingDataType2String(target_info->IndexType())));
}
}
return true;
}
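A brief summary of the dispatch chain this file introduces (reviewer orientation, not new code):

    // BindEmbeddingCast                       -- binds embedding -> sparse only when the target value type is kElemBit
    // EmbeddingTryCastToVarlen::Run<SparseT>  -- switches on the sparse index type
    // EmbeddingTryCastToSparse<IdxT>          -- switches on the embedding element type
    // EmbeddingTryCastToSparseImpl<IdxT, T>   -- validates the entries and appends them as indices to the fix heap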

} // namespace infinity
30 changes: 20 additions & 10 deletions src/function/cast/sparse_cast.cppm
@@ -65,6 +65,11 @@ void SparseTryCastToSparseFunInner(const SparseInfo *source_info,
FixHeapManager *target_fix_heap_mgr) {
const auto &[source_nnz, source_chunk_id, source_chunk_offset] = source;
target.nnz_ = source_nnz;
if (source_nnz == 0) {
target.chunk_id_ = -1;
target.chunk_offset_ = 0;
return;
}
const_ptr_t source_ptr = source_fix_heap_mgr->GetRawPtrFromChunk(source_chunk_id, source_chunk_offset);
SizeT sparse_bytes = source_info->SparseSize(source_nnz);
if constexpr (std::is_same_v<TargetValueType, SourceValueType>) {
@@ -93,6 +98,7 @@ void SparseTryCastToSparseFunInner(const SparseInfo *source_info,
const SizeT source_indice_size = source_info->IndiceSize(source_nnz);
Vector<Pair<const_ptr_t, SizeT>> data_ptrs;
UniquePtr<TargetIndiceType[]> target_indice_tmp_ptr;
UniquePtr<TargetValueType[]> target_value_tmp_ptr;
if constexpr (std::is_same_v<TargetIndiceType, SourceIndiceType>) {
data_ptrs.emplace_back(reinterpret_cast<const char *>(source_ptr), source_indice_size);
} else {
@@ -107,16 +113,19 @@ void SparseTryCastToSparseFunInner(const SparseInfo *source_info,
}
data_ptrs.emplace_back(reinterpret_cast<const char *>(target_indice_tmp_ptr.get()), target_indice_size);
}
if constexpr (!std::is_same_v<TargetValueType, BooleanT>) {
target_value_tmp_ptr = MakeUniqueForOverwrite<TargetValueType[]>(source_nnz);
const SizeT target_data_size = target_info->DataSize(source_nnz);
if (!EmbeddingTryCastToFixlen::Run(reinterpret_cast<const SourceValueType *>(source_ptr + source_indice_size),
reinterpret_cast<TargetValueType *>(target_value_tmp_ptr.get()),
source_nnz)) {
UnrecoverableError(fmt::format("Failed to cast from sparse with value type {} to sparse with value type {}",
DataType::TypeToString<SourceValueType>(),
DataType::TypeToString<TargetValueType>()));
}
data_ptrs.emplace_back(reinterpret_cast<const char *>(target_value_tmp_ptr.get()), target_data_size);
}
std::tie(target.chunk_id_, target.chunk_offset_) = target_fix_heap_mgr->AppendToHeap(data_ptrs);
}
}
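A quick worked example of what the new kElemBit branch changes (illustrative arithmetic only, assuming 4-byte indices, 4-byte float values, and nnz = 3):

    // source sparse<float, int32>: 3 * 4 index bytes + 3 * 4 value bytes = 24 bytes in the fix heap
    // target sparse<bit,   int32>: 3 * 4 index bytes only                = 12 bytes in the fix heap
    // With a BooleanT target value type the constexpr branch above skips the value conversion and
    // appends only the index buffer to data_ptrs.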
@@ -296,7 +305,8 @@ void SparseTryCastToSparseFun(const SparseInfo *source_info,
ColumnVector *target_vector_ptr) {
switch (target_info->DataType()) {
case kElemBit: {
UnrecoverableError("Unimplemented type");
SparseTryCastToSparseFunT1<BooleanT>(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
break;
}
case kElemInt8: {
SparseTryCastToSparseFunT1<TinyIntT>(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
4 changes: 3 additions & 1 deletion src/network/connection.cpp
Expand Up @@ -309,7 +309,9 @@ void Connection::SendTableDescription(const SharedPtr<DataTable> &result_table)
const auto *sparse_info = static_cast<SparseInfo *>(column_type->type_info().get());
switch (sparse_info->DataType()) {
case kElemBit: {
UnrecoverableError("Not implemented");
object_id = 1000;
object_width = 1;
break;
}
case kElemInt8: {
object_id = 1002;