Skip to content

Commit

Permalink
[ENH] Parameterized queries. (#3299)
Browse files Browse the repository at this point in the history
This introduces parameterized queries for selecting from the tiny
stories data set.
  • Loading branch information
rescrv authored Dec 13, 2024
1 parent 6b01bd5 commit 8779c3c
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 46 deletions.
8 changes: 4 additions & 4 deletions rust/load/src/bit_difference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ impl DataSet for SyntheticDataSet {
let collection = client.get_or_create_collection(&self.name(), None).await?;
let limit = gq.limit.sample(guac);
let mut ids = self.sample_ids(gq.skew, guac, limit);
let where_metadata = gq.metadata.map(|m| m.into_where_metadata(guac));
let where_document = gq.document.map(|m| m.into_where_document(guac));
let where_metadata = gq.metadata.map(|m| m.to_json(guac));
let where_document = gq.document.map(|m| m.to_json(guac));
let results = collection
.get(GetOptions {
ids: ids.clone(),
Expand Down Expand Up @@ -346,8 +346,8 @@ impl DataSet for SyntheticDataSet {
) -> Result<(), Box<dyn std::error::Error>> {
let collection = client.get_or_create_collection(&self.name(), None).await?;
let cluster = self.cluster_by_skew(vq.skew, guac);
let where_metadata = vq.metadata.map(|m| m.into_where_metadata(guac));
let where_document = vq.document.map(|m| m.into_where_document(guac));
let where_metadata = vq.metadata.map(|m| m.to_json(guac));
let where_document = vq.document.map(|m| m.to_json(guac));
let results = collection
.query(
QueryOptions {
Expand Down
8 changes: 4 additions & 4 deletions rust/load/src/data_sets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ impl DataSet for NopDataSet {
async fn query(
&self,
_: &ChromaClient,
_: QueryQuery,
qq: QueryQuery,
_: &mut Guacamole,
) -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("nop query");
tracing::info!("nop query {qq:?}", qq = qq);
Ok(())
}

Expand Down Expand Up @@ -113,8 +113,8 @@ impl DataSet for TinyStoriesDataSet {
) -> Result<(), Box<dyn std::error::Error>> {
let collection = client.get_collection(&self.name()).await?;
let limit = gq.limit.sample(guac);
let where_metadata = gq.metadata.map(|m| m.into_where_metadata(guac));
let where_document = gq.document.map(|m| m.into_where_document(guac));
let where_metadata = gq.metadata.map(|m| m.to_json(guac));
let where_document = gq.document.map(|m| m.to_json(guac));
let results = collection
.get(GetOptions {
ids: vec![],
Expand Down
101 changes: 70 additions & 31 deletions rust/load/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,44 +264,88 @@ impl PartialEq for Skew {
}
}

/////////////////////////////////////////// MetadataQuery //////////////////////////////////////////
///////////////////////////////////////// TinyStoriesMixin /////////////////////////////////////////

/// A metadata query specifies a metadata filter in Chroma.
#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum MetadataQuery {
/// A raw metadata query simply copies the provided filter spec.
#[serde(rename = "raw")]
Raw(serde_json::Value),
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum TinyStoriesMixin {
#[serde(rename = "numeric")]
Numeric { ratio_selected: f64 },
}

impl MetadataQuery {
/// Convert the metadata query into a JSON value suitable for use in a Chroma query.
pub fn into_where_metadata(self, _: &mut Guacamole) -> serde_json::Value {
impl TinyStoriesMixin {
pub fn to_json(&self, guac: &mut Guacamole) -> serde_json::Value {
match self {
MetadataQuery::Raw(json) => json,
Self::Numeric { ratio_selected } => {
let field: &'static str = match uniform(0u8, 5u8)(guac) {
0 => "i1",
1 => "i2",
2 => "i3",
3 => "f1",
4 => "f2",
5 => "f3",
_ => unreachable!(),
};
let mut center = uniform(0, 1_000_000)(guac);
let window = (1e6 * ratio_selected) as usize;
if window / 2 > center {
center = window / 2
}
let min = center - window / 2;
let max = center + window / 2;
serde_json::json!({"$and": [{field: {"$gte": min}}, {field: {"$lt": max}}]})
}
}
}
}

/////////////////////////////////////////// DocumentQuery //////////////////////////////////////////
//////////////////////////////////////////// WhereMixin ////////////////////////////////////////////

/// A document query specifies a document filter in Chroma.
#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum DocumentQuery {
// A raw document query simply copies the provided filter spec.
#[serde(rename = "raw")]
Raw(serde_json::Value),
/// A metadata query specifies a metadata filter in Chroma.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum WhereMixin {
/// A raw metadata query simply copies the provided filter spec.
#[serde(rename = "query")]
Constant(serde_json::Value),
/// The tiny stories workload. The way these collections were setup, there are three fields
/// each of integer, float, and string. The integer fields are named i1, i2, and i3. The
/// float fields are named f1, f2, and f3. The string fields are named s1, s2, and s3.
///
/// This mixin selects one of these 6 numeric fields at random and picks a metadata range query
/// to perform on it that will return data according to the mixin.
#[serde(rename = "tiny-stories")]
TinyStories(TinyStoriesMixin),
/// A constant operator with different comparison.
/// A mix of metadata queries selects one of the queries at random.
#[serde(rename = "select")]
Select(Vec<(f64, WhereMixin)>),
}

impl DocumentQuery {
/// Convert the document query into a JSON value suitable for use in a Chroma query.
pub fn into_where_document(self, _: &mut Guacamole) -> serde_json::Value {
impl WhereMixin {
/// Convert the metadata query into a JSON value suitable for use in a Chroma query.
pub fn to_json(&self, guac: &mut Guacamole) -> serde_json::Value {
match self {
DocumentQuery::Raw(json) => json,
Self::Constant(query) => query.clone(),
Self::TinyStories(mixin) => mixin.to_json(guac),
Self::Select(select) => {
let scale: f64 = any(guac);
let mut total = scale * select.iter().map(|(p, _)| *p).sum::<f64>();
for (p, mixin) in select {
if *p < 0.0 {
return serde_json::Value::Null;
}
if *p >= total {
return mixin.to_json(guac);
}
total -= *p;
}
serde_json::Value::Null
}
}
}
}

impl Eq for WhereMixin {}

///////////////////////////////////////////// GetQuery /////////////////////////////////////////////

/// A get query specifies a get operation in Chroma.
Expand All @@ -318,9 +362,9 @@ pub struct GetQuery {
pub skew: Skew,
pub limit: Distribution,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<MetadataQuery>,
pub metadata: Option<WhereMixin>,
#[serde(skip_serializing_if = "Option::is_none")]
pub document: Option<DocumentQuery>,
pub document: Option<WhereMixin>,
}

//////////////////////////////////////////// QueryQuery ////////////////////////////////////////////
Expand All @@ -339,9 +383,9 @@ pub struct QueryQuery {
pub skew: Skew,
pub limit: Distribution,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<MetadataQuery>,
pub metadata: Option<WhereMixin>,
#[serde(skip_serializing_if = "Option::is_none")]
pub document: Option<DocumentQuery>,
pub document: Option<WhereMixin>,
}

//////////////////////////////////////////// KeySelector ///////////////////////////////////////////
Expand Down Expand Up @@ -1505,13 +1549,11 @@ mod tests {
#[test]
fn workload_save_restore() {
const TEST_PATH: &str = "workload_save_restore.test.json";
println!("FINDME {}:{}", file!(), line!());
std::fs::remove_file(TEST_PATH).ok();
// First verse.
let mut load = LoadService::default();
load.set_persistent_path_and_load(Some(TEST_PATH.to_string()))
.unwrap();
println!("FINDME {}:{}", file!(), line!());
load.start(
"foo".to_string(),
"nop".to_string(),
Expand All @@ -1520,15 +1562,13 @@ mod tests {
Throughput::Constant(1.0),
)
.unwrap();
println!("FINDME {}:{}", file!(), line!());
let expected = {
// SAFETY(rescrv): Mutex poisoning.
let harness = load.harness.lock().unwrap();
assert_eq!(1, harness.running.len());
harness.running[0].clone()
};
drop(load);
println!("FINDME {}:{}", file!(), line!());
println!("expected: {:?}", expected);
// Second verse.
let mut load = LoadService::default();
Expand All @@ -1537,7 +1577,6 @@ mod tests {
let harness = load.harness.lock().unwrap();
assert!(harness.running.is_empty());
}
println!("FINDME {}:{}", file!(), line!());
load.set_persistent_path_and_load(Some(TEST_PATH.to_string()))
.unwrap();
{
Expand Down
26 changes: 19 additions & 7 deletions rust/load/src/workloads.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use crate::{
Distribution, DocumentQuery, GetQuery, KeySelector, MetadataQuery, QueryQuery, Skew, Workload,
Distribution, GetQuery, KeySelector, QueryQuery, Skew, TinyStoriesMixin, WhereMixin, Workload,
};

/// Return a map of all pre-configured workloads.
Expand All @@ -22,15 +22,19 @@ pub fn all_workloads() -> HashMap<String, Workload> {
skew: Skew::Zipf { theta: 0.999 },
limit: Distribution::Constant(10),
metadata: None,
document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))),
document: Some(WhereMixin::Constant(
serde_json::json!({"$contains": "the"}),
)),
}),
),
(
"get-metadata".to_string(),
Workload::Get(GetQuery {
skew: Skew::Zipf { theta: 0.999 },
limit: Distribution::Constant(10),
metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))),
metadata: Some(WhereMixin::TinyStories(TinyStoriesMixin::Numeric {
ratio_selected: 0.01,
})),
document: None,
}),
),
Expand All @@ -52,15 +56,19 @@ pub fn all_workloads() -> HashMap<String, Workload> {
skew: Skew::Zipf { theta: 0.999 },
limit: Distribution::Constant(10),
metadata: None,
document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))),
document: Some(WhereMixin::Constant(
serde_json::json!({"$contains": "the"}),
)),
}),
),
(
0.7,
Workload::Query(QueryQuery {
skew: Skew::Zipf { theta: 0.999 },
limit: Distribution::Constant(10),
metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))),
metadata: Some(WhereMixin::TinyStories(TinyStoriesMixin::Numeric {
ratio_selected: 0.01,
})),
document: None,
}),
),
Expand All @@ -75,15 +83,19 @@ pub fn all_workloads() -> HashMap<String, Workload> {
skew: Skew::Zipf { theta: 0.999 },
limit: Distribution::Constant(10),
metadata: None,
document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))),
document: Some(WhereMixin::Constant(
serde_json::json!({"$contains": "the"}),
)),
}),
),
(
0.25,
Workload::Get(GetQuery {
skew: Skew::Zipf { theta: 0.999 },
limit: Distribution::Constant(10),
metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))),
metadata: Some(WhereMixin::TinyStories(TinyStoriesMixin::Numeric {
ratio_selected: 0.01,
})),
document: None,
}),
),
Expand Down

0 comments on commit 8779c3c

Please sign in to comment.