Skip to content

Commit

Permalink
Make get_field UDF work for non-literal k in m[k] map lookup
Browse files Browse the repository at this point in the history
Specifically, make its return_type_from_exprs work for this case.
  • Loading branch information
vgapeyev committed May 23, 2024
1 parent f7430ad commit d911123
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,49 @@ impl ScalarUDFImpl for GetFieldFunc {
args.len()
);
}

let name = match &args[1] {
Expr::Literal(name) => name,
_ => {
let container_dt = &args[0].get_type(schema)?;
let accessor_expr = &args[1];
match (container_dt, accessor_expr) {
(DataType::Struct(_), Expr::Literal(name)) => {
let access_schema =
GetFieldAccessSchema::NamedStructField { name: name.clone() };
return access_schema
.get_accessed_field(container_dt)
.map(|f| f.data_type().clone());
}
(DataType::Struct(_), _) => {
return exec_err!(
"get_field function requires the argument field_name to be a string"
"get_field function, when accessing a Struct, requires the second argument to be a string literal, got {}",
accessor_expr
);
}
};
let access_schema = GetFieldAccessSchema::NamedStructField { name: name.clone() };
let arg_dt = args[0].get_type(schema)?;
access_schema
.get_accessed_field(&arg_dt)
.map(|f| f.data_type().clone())
(DataType::Map(fields, _), _) => {
// taken from field_util.rs GetFieldAccessSchema::get_accessed_field
match fields.data_type() {
DataType::Struct(fields) if fields.len() == 2 => {
// Arrow's MapArray is essentially a ListArray of structs with two columns. They are
// often named "key", and "value", but we don't require any specific naming here;
// instead, we assume that the second columnis the "value" column both here and in
// execution.
let value_field = fields
.get(1)
.expect("fields should have exactly two members");
return Ok(value_field.data_type().clone());
}
_ => {
return exec_err!(
"Map fields must contain a Struct with exactly 2 fields"
)
}
}
}
(dt, _) => {
return exec_err!(
"get_field function requires the first argument to be a Struct or a Map, got {}",
dt
);
}
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
Expand Down

0 comments on commit d911123

Please sign in to comment.