Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: nested list access #2966

Merged
merged 5 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/read_and_write.rst
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,12 @@ For example, the following filter string is acceptable:

.. code-block:: SQL

((label IN [10, 20]) AND (note.email IS NOT NULL))
OR NOT note.created
((label IN [10, 20]) AND (note['email'] IS NOT NULL))
OR NOT note['created']

Nested fields can be accessed using subscripts. Struct fields can be
subscripted by field name, while list fields can be subscripted by
index.

If your column name contains special characters or is a `SQL Keyword <https://docs.rs/sqlparser/latest/sqlparser/keywords/index.html>`_,
you can use backtick (`````) to escape it. For nested fields, each segment of the
Expand Down
27 changes: 27 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,33 @@ def test_nested_projection(tmp_path: Path):
)


def test_nested_projection_list(tmp_path: Path):
    """Project into a list-of-structs column using subscript syntax."""
    # Each row's "list_struct" value is a one-element list of structs:
    # [{"x": i, "y": i % 2 == 0}] for row i.
    rows = range(100)
    table = pa.Table.from_pydict(
        {
            "a": rows,
            "b": rows,
            "list_struct": [[{"x": i, "y": i % 2 == 0}] for i in rows],
        }
    )
    base_dir = tmp_path / "test"
    lance.write_dataset(table, base_dir)

    dataset = lance.dataset(base_dir)

    # List subscripts are 1-based here; ['x'] then drills into the struct.
    projected = dataset.to_table(columns={"list_struct": "list_struct[1]['x']"})
    assert projected == pa.Table.from_pydict({"list_struct": range(100)})

    # FIXME: sqlparser seems to ignore the .y part, but I can't create a simple
    # reproducible example for sqlparser. Possibly an issue in our dialect.
    # projected = dataset.to_table(
    #     columns={"list_struct": "array_element(list_struct, 1).y"})
    # assert projected == pa.Table.from_pydict(
    #     {"list_struct": [i % 2 == 0 for i in range(100)]}
    # )


def test_polar_scan(tmp_path: Path):
some_structs = [{"x": counter, "y": counter} for counter in range(100)]
table = pa.Table.from_pydict(
Expand Down
129 changes: 121 additions & 8 deletions rust/lance-datafusion/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::planner::ExprPlanner;
use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
use datafusion::logical_expr::{
AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowUDF,
AggregateUDF, ColumnarValue, GetFieldAccess, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
WindowUDF,
};
use datafusion::optimizer::simplify_expressions::SimplifyContext;
use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
use datafusion::sql::sqlparser::ast::{
Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo, Expr as SQLExpr,
Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, TimezoneInfo, UnaryOperator,
Value,
Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript, TimezoneInfo,
UnaryOperator, Value,
};
use datafusion::{
common::Column,
Expand Down Expand Up @@ -238,11 +239,15 @@ impl ContextProvider for LanceContextProvider {

/// Plans SQL expression strings into DataFusion logical expressions
/// against a fixed Arrow schema.
pub struct Planner {
/// Schema that column and nested-field references are resolved against.
schema: SchemaRef,
/// DataFusion context provider; supplies the `ExprPlanner`s used for
/// SQL-to-rel conversion and for planning subscript/field access.
context_provider: LanceContextProvider,
}

impl Planner {
pub fn new(schema: SchemaRef) -> Self {
Self { schema }
Self {
schema,
context_provider: LanceContextProvider::default(),
}
}

fn column(idents: &[Ident]) -> Expr {
Expand Down Expand Up @@ -403,9 +408,8 @@ impl Planner {
return self.legacy_parse_function(function);
}
}
let context_provider = LanceContextProvider::default();
let sql_to_rel = SqlToRel::new_with_options(
&context_provider,
&self.context_provider,
ParserOptions {
parse_float_as_decimal: false,
enable_ident_normalization: false,
Expand Down Expand Up @@ -516,6 +520,22 @@ impl Planner {
}
}

/// Plan a nested-field access (struct field or list index) by delegating to
/// the registered `ExprPlanner`s.
///
/// Each planner is offered the expression in turn; the first one that
/// returns `PlannerResult::Planned` wins. A planner that declines hands the
/// (possibly transformed) expression back via `PlannerResult::Original`, and
/// we pass that on to the next planner. If no planner handles it, this is an
/// invalid-input error.
fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
    let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
    for planner in self.context_provider.get_expr_planners() {
        field_access_expr = match planner.plan_field_access(field_access_expr, &df_schema)? {
            PlannerResult::Planned(planned) => return Ok(planned),
            PlannerResult::Original(unplanned) => unplanned,
        };
    }
    Err(Error::invalid_input(
        "Field access could not be planned",
        location!(),
    ))
}

fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
match expr {
SQLExpr::Identifier(id) => {
Expand Down Expand Up @@ -665,6 +685,67 @@ impl Planner {
expr: Box::new(self.parse_sql_expr(expr)?),
data_type: self.parse_type(data_type)?,
})),
SQLExpr::MapAccess { column, keys } => {
let mut expr = self.parse_sql_expr(column)?;

for key in keys {
let field_access = match &key.key {
SQLExpr::Value(
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
) => GetFieldAccess::NamedStructField {
name: ScalarValue::from(s.as_str()),
},
SQLExpr::JsonAccess { .. } => {
return Err(Error::invalid_input(
"JSON access is not supported",
location!(),
));
}
key => {
let key = Box::new(self.parse_sql_expr(key)?);
GetFieldAccess::ListIndex { key }
Comment on lines +705 to +706
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to verify this is an integer at some point?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in theory it could be an expression that resolves to an integer. I'm kind of trusting that it will be verified somewhere down the chain.

}
};

let field_access_expr = RawFieldAccessExpr { expr, field_access };

expr = self.plan_field_access(field_access_expr)?;
}

Ok(expr)
}
SQLExpr::Subscript { expr, subscript } => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhat annoying that both Subscript and MapAccess exist 😕

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah and which path is triggered depends on the dialect. Implemented both for now.

let expr = self.parse_sql_expr(expr)?;

let field_access = match subscript.as_ref() {
Subscript::Index { index } => match index {
SQLExpr::Value(
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
) => GetFieldAccess::NamedStructField {
name: ScalarValue::from(s.as_str()),
},
SQLExpr::JsonAccess { .. } => {
return Err(Error::invalid_input(
"JSON access is not supported",
location!(),
));
}
_ => {
let key = Box::new(self.parse_sql_expr(index)?);
GetFieldAccess::ListIndex { key }
}
},
Subscript::Slice { .. } => {
return Err(Error::invalid_input(
"Slice subscript is not supported",
location!(),
));
}
};

let field_access_expr = RawFieldAccessExpr { expr, field_access };
self.plan_field_access(field_access_expr)
}
_ => Err(Error::invalid_input(
format!("Expression '{expr}' is not supported SQL in lance"),
location!(),
Expand Down Expand Up @@ -828,7 +909,10 @@ mod tests {
TimestampNanosecondArray, TimestampSecondArray,
};
use arrow_schema::{DataType, Fields, Schema};
use datafusion::logical_expr::{lit, Cast};
use datafusion::{
logical_expr::{lit, Cast},
prelude::{array_element, get_field},
};
use datafusion_functions::core::expr_ext::FieldAccessor;

#[test]
Expand Down Expand Up @@ -983,6 +1067,35 @@ mod tests {
assert_column_eq(&planner, "st.st.s2", &expected);
assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
assert_column_eq(&planner, "st.st.`s2`", &expected);
assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
}

#[test]
fn test_nested_list_refs() {
    // Schema: l: List<item: Struct<f1: Utf8>>
    let item_struct =
        DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)]));
    let list_item = Arc::new(Field::new("item", item_struct, true));
    let schema = Arc::new(Schema::new(vec![Field::new(
        "l",
        DataType::List(list_item),
        true,
    )]));

    let planner = Planner::new(schema.clone());

    // A bare index subscript plans to array_element.
    let expected = array_element(col("l"), lit(0_i64));
    assert_eq!(planner.parse_expr("l[0]").unwrap(), expected);

    // A string subscript after an index drills into the struct field.
    let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
    assert_eq!(planner.parse_expr("l[0]['f1']").unwrap(), expected);

    // FIXME: This should work, but sqlparser doesn't recognize anything
    // after the period for some reason.
    // let expr = planner.parse_expr("l[0].f1").unwrap();
    // assert_eq!(expr, expected);
}

#[test]
Expand Down
Loading