Skip to content

Commit

Permalink
feat: add optimizing_threads to function (#375)
Browse files Browse the repository at this point in the history
* feat: add optimizing_threads to index

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

* put flexible at IndexProtect

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

* fix by comments

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>

---------

Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
  • Loading branch information
cutecutecat authored Mar 26, 2024
1 parent 9227095 commit e18616b
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 38 deletions.
43 changes: 33 additions & 10 deletions crates/base/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,39 @@ pub enum StatError {
NotExist,
}

/// Errors reported when altering a setting of an index.
///
/// The `#[error]` strings are the messages shown to the caller, so they are
/// part of the observable behavior.
#[must_use]
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
pub enum AlterError {
/// The requested setting key is not one of the recognized keys.
#[error("Setting key {key} is not exist.")]
BadKey { key: String },
/// The key is recognized, but the supplied value could not be accepted for it.
#[error("Setting key {key} has a wrong value {value}.")]
BadValue { key: String, value: String },
/// No index is registered under the requested handle.
#[error("Index not found.")]
NotExist,
}

/// Index options that are kept separately from `IndexOptions` and can be
/// replaced on an existing index through the `alter` operation.
///
/// Unknown fields are rejected on deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
pub struct IndexFlexibleOptions {
/// Number of threads for the optimizing thread pool.
///
/// Falls back to `IndexFlexibleOptions::default_optimizing_threads` when the
/// field is absent; validation accepts values from 1 to 65535 inclusive.
#[serde(default = "IndexFlexibleOptions::default_optimizing_threads")]
#[validate(range(min = 1, max = 65535))]
pub optimizing_threads: u16,
}

impl IndexFlexibleOptions {
    /// Default number of optimizing threads.
    ///
    /// Used both as the serde default for `optimizing_threads` and by the
    /// `Default` implementation of the struct.
    pub fn default_optimizing_threads() -> u16 {
        const DEFAULT_OPTIMIZING_THREADS: u16 = 1;
        DEFAULT_OPTIMIZING_THREADS
    }
}

impl Default for IndexFlexibleOptions {
    /// Builds the options with every field set to its documented default.
    fn default() -> Self {
        let optimizing_threads = Self::default_optimizing_threads();
        Self { optimizing_threads }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
#[validate(schema(function = "IndexOptions::validate_index_options"))]
Expand Down Expand Up @@ -198,9 +231,6 @@ pub struct OptimizingOptions {
#[serde(default = "OptimizingOptions::default_delete_threshold")]
#[validate(range(min = 0.01, max = 1.00))]
pub delete_threshold: f64,
#[serde(default = "OptimizingOptions::default_optimizing_threads")]
#[validate(range(min = 1, max = 65535))]
pub optimizing_threads: usize,
}

impl OptimizingOptions {
Expand All @@ -213,12 +243,6 @@ impl OptimizingOptions {
fn default_delete_threshold() -> f64 {
0.2
}
fn default_optimizing_threads() -> usize {
match std::thread::available_parallelism() {
Ok(threads) => (threads.get() as f64).sqrt() as _,
Err(_) => 1,
}
}
}

impl Default for OptimizingOptions {
Expand All @@ -227,7 +251,6 @@ impl Default for OptimizingOptions {
sealing_secs: Self::default_sealing_secs(),
sealing_size: Self::default_sealing_size(),
delete_threshold: Self::default_delete_threshold(),
optimizing_threads: Self::default_optimizing_threads(),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/base/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub trait WorkerOperations {
fn view_vbase(&self, handle: Handle) -> Result<impl ViewVbaseOperations, VbaseError>;
fn view_list(&self, handle: Handle) -> Result<impl ViewListOperations, ListError>;
fn stat(&self, handle: Handle) -> Result<IndexStat, StatError>;
fn alter(&self, handle: Handle, key: String, value: String) -> Result<(), AlterError>;
}

pub trait ViewBasicOperations {
Expand Down
3 changes: 3 additions & 0 deletions crates/common/src/clean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ pub fn clean(path: impl AsRef<Path>, wanted: impl Iterator<Item = String>) {
.unwrap();
let wanted = HashSet::<String>::from_iter(wanted);
for dir in dirs {
// The message promises to skip the entry, so actually skip it instead of
// letting a stray regular file fall through to the rest of the loop body.
if dir.path().is_file() {
    log::info!("Unexpected file {:?}, skip.", dir.path());
    continue;
}
let filename = dir.file_name();
let filename = filename.to_str().unwrap();
let p = path.as_ref().join(filename);
Expand Down
36 changes: 36 additions & 0 deletions crates/index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use std::collections::HashMap;
use std::collections::HashSet;
use std::convert::Infallible;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Instant;
Expand Down Expand Up @@ -80,6 +81,7 @@ impl<O: Op> Index<O> {
IndexStartup {
sealeds: HashSet::new(),
growings: HashSet::new(),
flexible: IndexFlexibleOptions::default(),
},
);
let delete = Delete::create(path.join("delete"));
Expand All @@ -96,6 +98,7 @@ impl<O: Op> Index<O> {
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
flexible: IndexFlexibleOptions::default(),
sealed: HashMap::new(),
growing: HashMap::new(),
delete: delete.clone(),
Expand All @@ -116,6 +119,7 @@ impl<O: Op> Index<O> {
.unwrap();
let tracker = Arc::new(IndexTracker { path: path.clone() });
let startup = FileAtomic::<IndexStartup>::open(path.join("startup"));
let flexible = startup.get().flexible.clone();
clean(
path.join("segments"),
startup
Expand Down Expand Up @@ -170,6 +174,7 @@ impl<O: Op> Index<O> {
}),
view: ArcSwap::new(Arc::new(IndexView {
options: options.clone(),
flexible,
delete: delete.clone(),
sealed,
growing,
Expand All @@ -188,6 +193,23 @@ impl<O: Op> Index<O> {
/// Returns the currently published `IndexView` as an owned `Arc`,
/// loaded from the index's `ArcSwap`.
pub fn view(&self) -> Arc<IndexView<O>> {
self.view.load_full()
}
/// Changes one alterable setting of this index.
///
/// Supported keys:
/// - `"optimizing.threads"`: a decimal integer in `0..=65535`. `0` selects
///   `IndexFlexibleOptions::default_optimizing_threads()`.
///
/// On success the new options are written to the startup state and a fresh
/// view carrying them is published.
///
/// # Errors
/// - `AlterError::BadKey` if `key` is not a supported key.
/// - `AlterError::BadValue` if `value` is not a valid integer for the key,
///   including negative numbers and numbers above `u16::MAX`.
pub fn alter(self: &Arc<Self>, key: String, value: String) -> Result<(), AlterError> {
    let mut protect = self.protect.lock();
    match key.as_str() {
        "optimizing.threads" => {
            // Parse straight into the target type. Going through `i32` and then
            // `as u16` silently wrapped out-of-range input ("-1" became 65535,
            // "65536" became 0) instead of rejecting it.
            let parsed = u16::from_str(value.as_str())
                .map_err(|_e| AlterError::BadValue { key, value })?;
            let optimizing_threads = match parsed {
                0 => IndexFlexibleOptions::default_optimizing_threads(),
                threads_limit => threads_limit,
            };
            protect.flexible_set(IndexFlexibleOptions { optimizing_threads });
            // Re-export the protected state so readers observe the new options.
            protect.maintain(self.options.clone(), self.delete.clone(), &self.view);
        }
        _ => return Err(AlterError::BadKey { key }),
    };
    Ok(())
}
pub fn refresh(&self) {
let mut protect = self.protect.lock();
if let Some((uuid, write)) = protect.write.clone() {
Expand Down Expand Up @@ -295,6 +317,7 @@ impl Drop for IndexTracker {

pub struct IndexView<O: Op> {
pub options: IndexOptions,
pub flexible: IndexFlexibleOptions,
pub delete: Arc<Delete>,
pub sealed: HashMap<Uuid, Arc<SealedSegment<O>>>,
pub growing: HashMap<Uuid, Arc<GrowingSegment<O>>>,
Expand Down Expand Up @@ -509,6 +532,7 @@ impl<O: Op> IndexView<O> {
/// Startup state of an index, stored through `FileAtomic` in the `startup`
/// file under the index path.
// NOTE(review): `flexible` is a newly added field. If startup files written
// before this field existed must still load, it may need a serde default;
// confirm against the `FileAtomic` format and the upgrade policy.
struct IndexStartup {
// Presumably the UUIDs of the sealed segments; verify against `maintain`.
sealeds: HashSet<Uuid>,
// Presumably the UUIDs of the growing segments; verify against `maintain`.
growings: HashSet<Uuid>,
// Alterable options restored into the view when the index is opened.
flexible: IndexFlexibleOptions,
}

struct IndexProtect<O: Op> {
Expand All @@ -519,14 +543,17 @@ struct IndexProtect<O: Op> {
}

impl<O: Op> IndexProtect<O> {
/// Export IndexProtect to IndexView
fn maintain(
&mut self,
options: IndexOptions,
delete: Arc<Delete>,
swap: &ArcSwap<IndexView<O>>,
) {
let old_startup = self.startup.get();
let view = Arc::new(IndexView {
options,
flexible: old_startup.flexible.clone(),
delete,
sealed: self.sealed.clone(),
growing: self.growing.clone(),
Expand All @@ -538,7 +565,16 @@ impl<O: Op> IndexProtect<O> {
self.startup.set(IndexStartup {
sealeds: startup_sealeds,
growings: startup_growings,
flexible: old_startup.flexible.clone(),
});
swap.swap(view);
}
/// Replaces the flexible options recorded in the startup state.
///
/// The sealed and growing segment sets are carried over unchanged; only
/// `flexible` is replaced.
fn flexible_set(&mut self, flexible: IndexFlexibleOptions) {
    let src = self.startup.get();
    self.startup.set(IndexStartup {
        sealeds: src.sealeds.clone(),
        // Copy the growing set from `growings`. It was previously filled from
        // `sealeds`, which overwrote the recorded growing segments on every alter.
        growings: src.growings.clone(),
        flexible,
    });
}
}
50 changes: 22 additions & 28 deletions crates/index/src/optimizing/indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use base::operator::Borrowed;
pub use base::search::*;
pub use base::vector::*;
use crossbeam::channel::RecvError;
use crossbeam::channel::TryRecvError;
use crossbeam::channel::{bounded, Receiver, RecvTimeoutError, Sender};
use std::cmp::Reverse;
use std::convert::Infallible;
Expand Down Expand Up @@ -102,34 +101,29 @@ impl<O: Op> OptimizerIndexing<O> {
}
fn main(self, shutdown: Receiver<Infallible>) {
let index = self.index;
rayon::ThreadPoolBuilder::new()
.num_threads(index.options.optimizing.optimizing_threads)
.build_scoped(|pool| {
std::thread::scope(|scope| {
scope.spawn(|| match shutdown.recv() {
Ok(never) => match never {},
Err(RecvError) => {
pool.stop();
}
});
loop {
if let Ok(()) = pool.install(|| optimizing_indexing(index.clone())) {
match shutdown.try_recv() {
Ok(never) => match never {},
Err(TryRecvError::Disconnected) => return,
Err(TryRecvError::Empty) => (),
}
continue;
}
match shutdown.recv_timeout(std::time::Duration::from_secs(60)) {
loop {
let view = index.view();
let threads = view.flexible.optimizing_threads;
rayon::ThreadPoolBuilder::new()
.num_threads(threads as usize)
.build_scoped(|pool| {
std::thread::scope(|scope| {
scope.spawn(|| match shutdown.recv() {
Ok(never) => match never {},
Err(RecvTimeoutError::Disconnected) => return,
Err(RecvTimeoutError::Timeout) => (),
}
}
});
})
.unwrap();
Err(RecvError) => {
pool.stop();
}
});
let _ = pool.install(|| optimizing_indexing(index.clone()));
})
})
.unwrap();
match shutdown.recv_timeout(std::time::Duration::from_secs(60)) {
Ok(never) => match never {},
Err(RecvTimeoutError::Disconnected) => return,
Err(RecvTimeoutError::Timeout) => (),
}
}
}
}

Expand Down
20 changes: 20 additions & 0 deletions crates/service/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,26 @@ impl Instance {
Instance::Veci8Dot(x) => x.stat(),
}
}
/// Forwards an alter request to the wrapped index.
///
/// Every variant delegates to its inner value's `alter(key, value)` and
/// returns that result unchanged; the match exists only because each variant
/// wraps a different concrete type.
pub fn alter(&self, key: String, value: String) -> Result<(), AlterError> {
match self {
Instance::Vecf32Cos(x) => x.alter(key, value),
Instance::Vecf32Dot(x) => x.alter(key, value),
Instance::Vecf32L2(x) => x.alter(key, value),
Instance::Vecf16Cos(x) => x.alter(key, value),
Instance::Vecf16Dot(x) => x.alter(key, value),
Instance::Vecf16L2(x) => x.alter(key, value),
Instance::SVecf32Cos(x) => x.alter(key, value),
Instance::SVecf32Dot(x) => x.alter(key, value),
Instance::SVecf32L2(x) => x.alter(key, value),
Instance::BVecf32Cos(x) => x.alter(key, value),
Instance::BVecf32Dot(x) => x.alter(key, value),
Instance::BVecf32L2(x) => x.alter(key, value),
Instance::BVecf32Jaccard(x) => x.alter(key, value),
Instance::Veci8L2(x) => x.alter(key, value),
Instance::Veci8Cos(x) => x.alter(key, value),
Instance::Veci8Dot(x) => x.alter(key, value),
}
}
pub fn start(&self) {
match self {
Instance::Vecf32Cos(x) => x.start(),
Expand Down
5 changes: 5 additions & 0 deletions crates/service/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ impl WorkerOperations for Worker {
let stat = instance.stat();
Ok(stat)
}
/// Looks up the instance registered under `handle` in the current worker
/// view and forwards the alter request to it.
///
/// Returns `AlterError::NotExist` when no instance is found for `handle`.
fn alter(&self, handle: Handle, key: String, value: String) -> Result<(), AlterError> {
    let snapshot = self.view();
    match snapshot.get(handle) {
        Some(instance) => instance.alter(key, value),
        None => Err(AlterError::NotExist),
    }
}
}

pub struct WorkerView {
Expand Down
8 changes: 8 additions & 0 deletions src/bgworker/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ fn session(worker: Arc<Worker>, handler: ServerRpcHandler) -> Result<Infallible,
ServerRpcHandle::Stat { handle, x } => {
handler = x.leave(worker.stat(handle))?;
}
ServerRpcHandle::Alter {
handle,
key,
value,
x,
} => {
handler = x.leave(worker.alter(handle, key, value))?;
}
ServerRpcHandle::Basic {
handle,
vector,
Expand Down
11 changes: 11 additions & 0 deletions src/index/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@ use crate::error::*;
use crate::index::utils::from_oid_to_handle;
use crate::ipc::client;
use base::index::*;
use pgrx::error;

#[pgrx::pg_extern(volatile, strict)]
fn _vectors_alter_vector_index(oid: pgrx::pg_sys::Oid, key: String, value: String) {
let id = from_oid_to_handle(oid);
let mut rpc = check_client(client());
match rpc.alter(id, key, value) {
Ok(_) => {}
Err(e) => error!("{}", e.to_string()),
}
}

#[pgrx::pg_extern(volatile, strict, parallel_safe)]
fn _vectors_index_stat(
Expand Down
1 change: 1 addition & 0 deletions src/ipc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,4 +329,5 @@ defines! {
stream vbase(handle: Handle, vector: OwnedVector, opts: SearchOptions) -> Pointer;
stream list(handle: Handle) -> Pointer;
unary stat(handle: Handle) -> IndexStat;
unary alter(handle: Handle, key: String, value: String) -> ();
}
3 changes: 3 additions & 0 deletions src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,9 @@ BEGIN
END;
$$;

-- alter_vector_index(index, key, value): change one setting of a vector index.
-- Implemented in C by the symbol `_vectors_alter_vector_index_wrapper`.
-- STRICT: the function is not called when any argument is NULL.
CREATE FUNCTION alter_vector_index("index" OID, "key" TEXT, "value" TEXT) RETURNS void
STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_alter_vector_index_wrapper';

-- List of casts

CREATE CAST (real[] AS vector)
Expand Down
37 changes: 37 additions & 0 deletions tests/sqllogictest/index_edit.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Exercises alter_vector_index: error reporting for a bad index, key and
# value, followed by a successful change of `optimizing.threads`.

statement ok
SET search_path TO pg_temp, vectors;

statement ok
CREATE TABLE t (val vector(3));

# Populate the table with 1000 random 3-dimensional vectors.
statement ok
INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);

statement ok
CREATE INDEX hnsw_1 ON t USING vectors (val vector_l2_ops)
WITH (options = "[indexing.hnsw]");

# Sanity check before altering: a top-10 query returns 10 rows.
# NOTE(review): the index is built with vector_l2_ops but the query orders by
# `<#>`; confirm this operator is the intended one for exercising the index.
query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5]' limit 10) t2;
----
10

# The relation name cannot be resolved, so the statement fails.
statement error does not exist
SELECT alter_vector_index('unknown_index'::regclass::oid, 'optimizing.threads', '1');

# An unrecognized setting key is rejected.
statement error Setting key
SELECT alter_vector_index('hnsw_1'::regclass::oid, 'unknown_key', '1');

# A non-numeric value for a recognized key is rejected.
statement error wrong value
SELECT alter_vector_index('hnsw_1'::regclass::oid, 'optimizing.threads', 'unknown_value');

# A valid key/value pair is accepted.
statement ok
SELECT alter_vector_index('hnsw_1'::regclass::oid, 'optimizing.threads', '1');

# The same query still returns 10 rows after the alter.
query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5]' limit 10) t2;
----
10

statement ok
DROP TABLE t;

0 comments on commit e18616b

Please sign in to comment.