Skip to content

Commit

Permalink
feat: support between sql clauses (#3225)
Browse files Browse the repository at this point in the history
This adds support for the sql `col BETWEEN x AND y` clause

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
connellPortrait and westonpace authored Dec 10, 2024
1 parent faf776d commit 7ec23f0
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 4 deletions.
1 change: 1 addition & 0 deletions python/python/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_sql_predicates(dataset):
("int >= 50", 50),
("int = 50", 1),
("int != 50", 99),
("int BETWEEN 50 AND 60", 11),
("float < 30.0", 45),
("str = 'aa'", 16),
("str in ('aa', 'bb')", 26),
Expand Down
19 changes: 19 additions & 0 deletions python/python/tests/test_scalar_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,25 @@ def test_indexed_scalar_scan(indexed_dataset: lance.LanceDataset, data_table: pa
assert actual_price == expected_price


def test_indexed_between(tmp_path):
dataset = lance.write_dataset(pa.table({"val": range(100)}), tmp_path)
dataset.create_scalar_index("val", index_type="BTREE")

scanner = dataset.scanner(filter="val BETWEEN 10 AND 20", prefilter=True)

assert "MaterializeIndex" in scanner.explain_plan()

actual_data = scanner.to_table()
assert actual_data.num_rows == 11

scanner = dataset.scanner(filter="val >= 10 AND val <= 20", prefilter=True)

assert "MaterializeIndex" in scanner.explain_plan()

actual_data = scanner.to_table()
assert actual_data.num_rows == 11


def test_temporal_index(tmp_path):
# Timestamps
now = datetime.now()
Expand Down
113 changes: 112 additions & 1 deletion rust/lance-datafusion/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use datafusion::sql::sqlparser::ast::{
};
use datafusion::{
common::Column,
logical_expr::{col, BinaryExpr, Like, Operator},
logical_expr::{col, Between, BinaryExpr, Like, Operator},
physical_expr::execution_props::ExecutionProps,
physical_plan::PhysicalExpr,
prelude::Expr,
Expand Down Expand Up @@ -746,6 +746,25 @@ impl Planner {
let field_access_expr = RawFieldAccessExpr { expr, field_access };
self.plan_field_access(field_access_expr)
}
SQLExpr::Between {
expr,
negated,
low,
high,
} => {
// Parse the main expression and bounds
let expr = self.parse_sql_expr(expr)?;
let low = self.parse_sql_expr(low)?;
let high = self.parse_sql_expr(high)?;

let between = Expr::Between(Between::new(
Box::new(expr),
*negated,
Box::new(low),
Box::new(high),
));
Ok(between)
}
_ => Err(Error::invalid_input(
format!("Expression '{expr}' is not supported SQL in lance"),
location!(),
Expand Down Expand Up @@ -1463,6 +1482,98 @@ mod tests {
}
}

#[test]
fn test_sql_between() {
use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use std::sync::Arc;

let schema = Arc::new(Schema::new(vec![
Field::new("x", DataType::Int32, false),
Field::new("y", DataType::Float64, false),
Field::new(
"ts",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
]));

let planner = Planner::new(schema.clone());

// Test integer BETWEEN
let expr = planner
.parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
.unwrap();
let physical_expr = planner.create_physical_expr(&expr).unwrap();

// Create timestamp array with values representing:
// 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
let ts_array = TimestampMicrosecondArray::from_iter_values(
(0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
);

let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
Arc::new(ts_array),
],
)
.unwrap();

let predicates = physical_expr.evaluate(&batch).unwrap();
assert_eq!(
predicates.into_array(0).unwrap().as_ref(),
&BooleanArray::from(vec![
false, false, false, true, true, true, true, true, false, false
])
);

// Test NOT BETWEEN
let expr = planner
.parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
.unwrap();
let physical_expr = planner.create_physical_expr(&expr).unwrap();

let predicates = physical_expr.evaluate(&batch).unwrap();
assert_eq!(
predicates.into_array(0).unwrap().as_ref(),
&BooleanArray::from(vec![
true, true, true, false, false, false, false, false, true, true
])
);

// Test floating point BETWEEN
let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
let physical_expr = planner.create_physical_expr(&expr).unwrap();

let predicates = physical_expr.evaluate(&batch).unwrap();
assert_eq!(
predicates.into_array(0).unwrap().as_ref(),
&BooleanArray::from(vec![
false, false, false, true, true, true, true, false, false, false
])
);

// Test timestamp BETWEEN
let expr = planner
.parse_filter(
"ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
)
.unwrap();
let physical_expr = planner.create_physical_expr(&expr).unwrap();

let predicates = physical_expr.evaluate(&batch).unwrap();
assert_eq!(
predicates.into_array(0).unwrap().as_ref(),
&BooleanArray::from(vec![
false, false, false, true, true, true, true, true, false, false
])
);
}

#[test]
fn test_sql_comparison() {
// Create a batch with all data types
Expand Down
120 changes: 117 additions & 3 deletions rust/lance-index/src/scalar/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use datafusion_expr::{
use futures::join;
use lance_core::{utils::mask::RowIdMask, Result};
use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
use log::warn;
use tracing::instrument;

use super::{AnyQuery, LabelListQuery, SargableQuery, ScalarIndex};
Expand Down Expand Up @@ -564,16 +565,85 @@ fn visit_comparison(
let scalar = maybe_scalar(&expr.right, col_type)?;
query_parser.visit_comparison(column, scalar, &expr.op)
} else {
let (column, col_type, query_parser) = maybe_indexed_column(&expr.right, index_info)?;
let scalar = maybe_scalar(&expr.left, col_type)?;
query_parser.visit_comparison(column, scalar, &expr.op)
// Datafusion's query simplifier will canonicalize expressions and so we shouldn't reach this case. If, for some reason, we
// do reach this case we can handle it in the future by inverting expr.op and swapping the left and right sides
warn!("Unexpected comparison encountered (DF simplifier should have removed this case). Scalar indices will not be applied");
None
}
}

fn maybe_between(expr: &BinaryExpr) -> Option<Between> {
let left_comparison = match expr.left.as_ref() {
Expr::BinaryExpr(binary_expr) => Some(binary_expr),
_ => None,
}?;
let right_comparison = match expr.right.as_ref() {
Expr::BinaryExpr(binary_expr) => Some(binary_expr),
_ => None,
}?;

match (left_comparison.op, right_comparison.op) {
(Operator::GtEq, Operator::LtEq) => {
// We have x >= y && a <= b.
// If x == a then it is a between query
// if y == b then it is a between query
if left_comparison.left == right_comparison.left {
Some(Between {
expr: left_comparison.left.clone(),
low: left_comparison.right.clone(),
high: right_comparison.right.clone(),
negated: false,
})
} else if left_comparison.right == right_comparison.right {
Some(Between {
expr: left_comparison.right.clone(),
low: right_comparison.left.clone(),
high: left_comparison.left.clone(),
negated: false,
})
} else {
None
}
}
(Operator::LtEq, Operator::GtEq) => {
// Same logic as above we just switch the low/high
if left_comparison.left == right_comparison.left {
Some(Between {
expr: left_comparison.left.clone(),
low: right_comparison.right.clone(),
high: left_comparison.right.clone(),
negated: false,
})
} else if left_comparison.right == right_comparison.right {
Some(Between {
expr: left_comparison.right.clone(),
low: left_comparison.left.clone(),
high: right_comparison.left.clone(),
negated: false,
})
} else {
None
}
}
_ => None,
}
}

fn visit_and(
expr: &BinaryExpr,
index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
// Many scalar indices can efficiently handle a BETWEEN query as a single search and this
// can be much more efficient than two separate range queries. As an optimization we check
// to see if this is a between query and, if so, we handle it as a single query
//
// Note: We can't rely on users writing the SQL BETWEEN operator because:
// * Some users won't realize it's an option or a good idea
// * Datafusion's simplifier will rewrite the BETWEEN operator into two separate range queries
if let Some(between) = maybe_between(expr) {
return visit_between(&between, index_info);
}

let left = visit_node(&expr.left, index_info);
let right = visit_node(&expr.right, index_info);
match (left, right) {
Expand Down Expand Up @@ -912,6 +982,7 @@ mod tests {
]);

check_no_index(&index_info, "size BETWEEN 5 AND 10");
// 5 different ways of writing BETWEEN (all should be recognized)
check_simple(
&index_info,
"aisle BETWEEN 5 AND 10",
Expand All @@ -921,6 +992,45 @@ mod tests {
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);
check_simple(
&index_info,
"aisle >= 5 AND aisle <= 10",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);

check_simple(
&index_info,
"aisle <= 10 AND aisle >= 5",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);

check_simple(
&index_info,
"5 <= aisle AND 10 >= aisle",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);

check_simple(
&index_info,
"10 >= aisle AND 5 <= aisle",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);
check_simple(
&index_info,
"on_sale IS TRUE",
Expand Down Expand Up @@ -1023,6 +1133,10 @@ mod tests {
Bound::Unbounded,
),
);
// In the future we can handle this case if we need to. For
// now let's make sure we don't accidentally do the wrong thing
// (we were getting this backwards in the past)
check_no_index(&index_info, "10 > aisle");
check_simple(
&index_info,
"aisle >= 10",
Expand Down

0 comments on commit 7ec23f0

Please sign in to comment.