diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index b00b8ea553f2..2f7e0300f298 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -72,20 +72,49 @@ impl ScalarUDFImpl for GetFieldFunc { args.len() ); } - - let name = match &args[1] { - Expr::Literal(name) => name, - _ => { + let container_dt = &args[0].get_type(schema)?; + let accessor_expr = &args[1]; + match (container_dt, accessor_expr) { + (DataType::Struct(_), Expr::Literal(name)) => { + let access_schema = + GetFieldAccessSchema::NamedStructField { name: name.clone() }; + return access_schema + .get_accessed_field(container_dt) + .map(|f| f.data_type().clone()); + } + (DataType::Struct(_), _) => { return exec_err!( - "get_field function requires the argument field_name to be a string" + "get_field function, when accessing a Struct, requires the second argument to be a string literal, got {}", + accessor_expr ); } - }; - let access_schema = GetFieldAccessSchema::NamedStructField { name: name.clone() }; - let arg_dt = args[0].get_type(schema)?; - access_schema - .get_accessed_field(&arg_dt) - .map(|f| f.data_type().clone()) + (DataType::Map(fields, _), _) => { + // taken from field_util.rs GetFieldAccessSchema::get_accessed_field + match fields.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + // Arrow's MapArray is essentially a ListArray of structs with two columns. They are + // often named "key", and "value", but we don't require any specific naming here; + // instead, we assume that the second columnis the "value" column both here and in + // execution. + let value_field = fields + .get(1) + .expect("fields should have exactly two members"); + return Ok(value_field.data_type().clone()); + } + _ => { + return exec_err!( + "Map fields must contain a Struct with exactly 2 fields" + ) + } + } + } + (dt, _) => { + return exec_err!( + "get_field function requires the first argument to be a Struct or a Map, got {}", + dt + ); + } + } } fn invoke(&self, args: &[ColumnarValue]) -> Result {