feat(storage): implement in memory FLAT vector index by wenym1 · Pull Request #21399 · risingwavelabs/risingwave · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat(storage): implement in memory FLAT vector index #21399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 77 additions & 9 deletions src/storage/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
use std::mem::take;
use std::ops::Bound::{Excluded, Included, Unbounded};
use std::ops::{Bound, RangeBounds};
use std::sync::{Arc, LazyLock};
Expand All @@ -34,16 +35,17 @@ use thiserror_ext::AsReport;
use tokio::task::yield_now;
use tracing::error;

use crate::dispatch_measurement;
use crate::error::StorageResult;
use crate::hummock::HummockError;
use crate::hummock::utils::{
do_delete_sanity_check, do_insert_sanity_check, do_update_sanity_check, merge_stream,
sanity_check_enabled,
};
use crate::mem_table::{KeyOp, MemTable};
use crate::panic_store::PanicStateStore;
use crate::storage_value::StorageValue;
use crate::store::*;
use crate::vector::{MeasureDistanceBuilder, NearestBuilder};

/// A [`FullKey`] backed by [`Bytes`].
pub type BytesFullKey = FullKey<Bytes>;
/// A `(start, end)` pair of bounds over [`BytesFullKey`]s, as returned by `to_full_key_range`.
pub type BytesFullKeyRange = (Bound<BytesFullKey>, Bound<BytesFullKey>);
Expand Down Expand Up @@ -257,13 +259,15 @@ pub mod sled {
RangeKvStateStore {
inner: SledRangeKv::new(path),
tables: Default::default(),
vectors: Default::default(),
}
}

pub fn new_temp() -> Self {
RangeKvStateStore {
inner: SledRangeKv::new_temp(),
tables: Default::default(),
vectors: Default::default(),
}
}
}
Expand Down Expand Up @@ -563,6 +567,8 @@ impl TableState {
}
}

/// Per-table in-memory vector storage.
///
/// Each entry is `(vector, info, epoch)`: the vector, its associated `info` bytes, and the
/// epoch it was ingested at. Read snapshots use that epoch to hide vectors ingested after
/// the snapshot's own epoch. The map is behind an `Arc`, so clones of the owning store
/// share it.
type InMemVectorStore = Arc<RwLock<HashMap<TableId, Vec<(Vector, Bytes, u64)>>>>;

/// An in-memory state store
///
/// The in-memory state store is a [`BTreeMap`], which maps [`FullKey`] to value. It
Expand All @@ -574,6 +580,8 @@ pub struct RangeKvStateStore<R: RangeKv> {
inner: R,
/// `table_id` -> `prev_epoch` -> `curr_epoch`
tables: Arc<parking_lot::Mutex<HashMap<TableId, TableState>>>,

vectors: InMemVectorStore,
}

fn to_full_key_range<R, B>(table_id: TableId, table_key_range: R) -> BytesFullKeyRange
Expand Down Expand Up @@ -722,11 +730,42 @@ impl<R: RangeKv> StateStoreRead for RangeKvStateStoreReadSnapshot<R> {
impl<R: RangeKv> StateStoreReadVector for RangeKvStateStoreReadSnapshot<R> {
/// Searches this snapshot's table for vectors near `vec`, considering only vectors
/// ingested at or before the snapshot epoch.
///
/// This is an exhaustive ("FLAT") scan: every stored vector of the table is fed to a
/// `NearestBuilder` sized by `options.top_n`. `on_nearest_item_fn` is handed to the
/// builder, which presumably applies it to each selected item to produce the returned
/// `O` values — confirm in `crate::vector`.
// NOTE(review): this excerpt is a rendered diff with removed and added lines interleaved;
// the `_`-prefixed parameters and the `unimplemented!()` body appear to be the removed
// pre-change version, each followed by its replacement.
async fn nearest<O: Send + 'static>(
&self,
_vec: Vector,
_options: VectorNearestOptions,
_on_nearest_item_fn: impl OnNearestItemFn<O>,
vec: Vector,
options: VectorNearestOptions,
on_nearest_item_fn: impl OnNearestItemFn<O>,
) -> StorageResult<Vec<O>> {
unimplemented!()
// Generic over the distance measurement `M` so each measurement gets its own
// monomorphised scan. Visits every vector stored for `table_id` whose ingestion
// epoch is <= `epoch`.
fn nearest_impl<M: MeasureDistanceBuilder, O>(
store: &InMemVectorStore,
epoch: u64,
table_id: TableId,
vec: Vector,
options: VectorNearestOptions,
on_nearest_item_fn: impl OnNearestItemFn<O>,
) -> Vec<O> {
let mut builder = NearestBuilder::<'_, O, M>::new(vec.to_ref(), options.top_n);
// The read guard from `store.read()` is a temporary of this statement, so the lock
// is held until `builder.add` returns; concurrent `ingest_vectors` calls wait for it.
builder.add(
store
.read()
.get(&table_id)
// `vec` in this closure is the table's stored list, not the query vector.
.map(|vec| vec.iter())
// A table with no entry yields an empty scan rather than an error.
.into_iter()
.flatten()
// Snapshot visibility: skip vectors ingested after the snapshot epoch.
.filter(|(_, _, vector_epoch)| epoch >= *vector_epoch)
.map(|(vec, info, _)| (vec.to_ref(), info.as_ref())),
on_nearest_item_fn,
);
builder.finish()
}
// NOTE(review): `dispatch_measurement!` is defined elsewhere; presumably it expands to a
// match on `options.measure` that binds `MeasurementType` to the matching
// `MeasureDistanceBuilder` type — confirm.
dispatch_measurement!(options.measure, MeasurementType, {
Ok(nearest_impl::<MeasurementType, O>(
&self.inner.vectors,
self.epoch,
self.table_id,
vec,
options,
on_nearest_item_fn,
))
})
}
}

Expand Down Expand Up @@ -884,12 +923,20 @@ impl<R: RangeKv> RangeKvStateStore<R> {
}))?;
Ok(size)
}

/// Makes a batch of vectors written at `epoch` visible in `table_id`'s in-memory list.
///
/// Each `(vector, info)` pair is appended together with `epoch`. Read snapshots compare
/// that epoch with their own, so a snapshot never sees vectors ingested after it.
///
/// An empty batch is a no-op: it takes no write lock and creates no per-table entry.
/// The local state store hands over its whole pending buffer (`take(&mut self.vectors)`)
/// on every batch write, and that buffer is empty for tables that never insert vectors.
/// Without this check, every such write would contend on the global write lock and leave
/// an empty `Vec` behind for the table.
fn ingest_vectors(&self, table_id: TableId, epoch: u64, vecs: Vec<(Vector, Bytes)>) {
    if vecs.is_empty() {
        return;
    }
    self.vectors
        .write()
        .entry(table_id)
        .or_default()
        .extend(vecs.into_iter().map(|(vec, info)| (vec, info, epoch)));
}
}

impl<R: RangeKv> StateStore for RangeKvStateStore<R> {
type Local = RangeKvLocalStateStore<R>;
type ReadSnapshot = RangeKvStateStoreReadSnapshot<R>;
type VectorWriter = PanicStateStore;
type VectorWriter = RangeKvLocalStateStore<R>;

async fn try_wait_epoch(
&self,
Expand All @@ -912,13 +959,23 @@ impl<R: RangeKv> StateStore for RangeKvStateStore<R> {
Ok(self.new_read_snapshot_impl(epoch.get_epoch(), options.table_id))
}

// NOTE(review): rendered diff — the next two lines appear to be the removed stub; the
// implementation that replaces it follows.
async fn new_vector_writer(&self, _options: NewVectorWriterOptions) -> Self::VectorWriter {
unimplemented!()
/// Creates a vector writer for `options.table_id`.
///
/// The writer is a `RangeKvLocalStateStore` built over a clone of this store. The store's
/// `vectors` field is an `Arc`, so the clone presumably shares the same vector map, and
/// vectors the writer buffers reach this store when it ingests them. The writer uses the
/// default consistency level and table option, is not replicated, and gets a vnode bitmap
/// containing a single set bit.
// NOTE(review): a single vnode presumably suffices because the in-memory vector store is
// not partitioned by vnode — confirm.
async fn new_vector_writer(&self, options: NewVectorWriterOptions) -> Self::VectorWriter {
RangeKvLocalStateStore::new(
self.clone(),
NewLocalOptions {
table_id: options.table_id,
op_consistency_level: Default::default(),
table_option: Default::default(),
is_replicated: false,
vnodes: Arc::new(Bitmap::from_bool_slice(&[true])),
},
)
}
}

pub struct RangeKvLocalStateStore<R: RangeKv> {
mem_table: MemTable,
vectors: Vec<(Vector, Bytes)>,
inner: RangeKvStateStore<R>,

epoch: Option<EpochPair>,
Expand All @@ -939,6 +996,7 @@ impl<R: RangeKv> RangeKvLocalStateStore<R> {
op_consistency_level: option.op_consistency_level,
table_option: option.table_option,
vnodes: option.vnodes,
vectors: vec![],
}
}

Expand Down Expand Up @@ -1108,8 +1166,11 @@ impl<R: RangeKv> StateStoreWriteEpochControl for RangeKvLocalStateStore<R> {
}
}
}
let epoch = self.epoch();
self.inner
.ingest_batch(kv_pairs, vec![], self.epoch(), self.table_id)
.ingest_vectors(self.table_id, epoch, take(&mut self.vectors));
self.inner
.ingest_batch(kv_pairs, vec![], epoch, self.table_id)
}

async fn init(&mut self, options: InitOptions) -> StorageResult<()> {
Expand Down Expand Up @@ -1219,6 +1280,13 @@ impl<R: RangeKv> StateStoreWriteEpochControl for RangeKvLocalStateStore<R> {
}
}

impl<R: RangeKv> StateStoreWriteVector for RangeKvLocalStateStore<R> {
    /// Buffers `vec` and its `info` bytes in this local store.
    ///
    /// Nothing reaches the shared store until the local store hands its pending vectors
    /// over. This call never fails.
    fn insert(&mut self, vec: Vector, info: Bytes) -> StorageResult<()> {
        let pending = &mut self.vectors;
        pending.push((vec, info));
        Ok(())
    }
}

pub struct RangeKvStateStoreIter<R: RangeKv> {
inner: batched_iter::Iter<R>,

Expand Down
4 changes: 3 additions & 1 deletion src/storage/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use risingwave_pb::hummock::PbVnodeWatermark;
use crate::error::{StorageError, StorageResult};
use crate::hummock::CachePolicy;
use crate::monitor::{MonitoredStateStore, MonitoredStorageMetrics};
pub(crate) use crate::vector::{DistanceMeasurement, OnNearestItemFn, Vector};
pub(crate) use crate::vector::{DistanceMeasurement, OnNearestItem, Vector};

/// Shorthand bound for types that are `Send + Sync + 'static`.
pub trait StaticSendSync = Send + Sync + 'static;

Expand Down Expand Up @@ -435,6 +435,8 @@ pub struct VectorNearestOptions {
pub measure: DistanceMeasurement,
}

/// Callback bound taken by vector `nearest` searches: an `OnNearestItem<O>` that is also
/// `Send + 'static`.
pub trait OnNearestItemFn<O> = OnNearestItem<O> + Send + 'static;

pub trait StateStoreReadVector: StaticSendSync {
fn nearest<O: Send + 'static>(
&self,
Expand Down
Loading
Loading
0