Skip to content

Commit

Permalink
fix: unwrap dictionaries in CreateNamedStruct (apache#754)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Aug 2, 2024
1 parent 39e030b commit 698c1b2
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use arrow::compute::take;
use arrow::record_batch::RecordBatch;
use arrow_array::types::Int32Type;
use arrow_array::{Array, DictionaryArray, StructArray};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_physical_expr::PhysicalExpr;
use std::{
any::Any,
fmt::{Display, Formatter},
hash::{Hash, Hasher},
sync::Arc,
};

use arrow::record_batch::RecordBatch;
use arrow_array::StructArray;
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_physical_expr::PhysicalExpr;

use crate::execution::datafusion::expressions::utils::down_cast_any_ref;

#[derive(Debug, Hash)]
Expand Down Expand Up @@ -63,6 +64,21 @@ impl PhysicalExpr for CreateNamedStruct {
.map(|expr| expr.evaluate(batch))
.collect::<datafusion_common::Result<Vec<_>>>()?;
let arrays = ColumnarValue::values_to_arrays(&values)?;
// TODO it would be more efficient if we could preserve dictionaries within the
// struct array but for now we unwrap them to avoid runtime errors
// https://github.com/apache/datafusion-comet/issues/755
let arrays = arrays
.iter()
.map(|array| {
if let Some(dict_array) =
array.as_any().downcast_ref::<DictionaryArray<Int32Type>>()
{
take(dict_array.values().as_ref(), dict_array.keys(), None)
} else {
Ok(Arc::clone(array))
}
})
.collect::<Result<Vec<_>, _>>()?;
let fields = match &self.data_type {
DataType::Struct(fields) => fields,
_ => {
Expand Down Expand Up @@ -125,3 +141,51 @@ impl PartialEq<dyn Any> for CreateNamedStruct {
.unwrap_or(false)
}
}

#[cfg(test)]
mod test {
use super::CreateNamedStruct;
use arrow_array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Fields, Schema};
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr_common::expressions::column::Column;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

#[test]
fn test_create_struct_from_dict_encoded_i32() -> Result<()> {
let keys = Int32Array::from(vec![0, 1, 2]);
let values = Int32Array::from(vec![0, 111, 233]);
let dict = DictionaryArray::try_new(keys, Arc::new(values))?;
let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
let schema = Schema::new(vec![Field::new("a", data_type, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?;
let data_type =
DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, false)]));
let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], data_type);
let ColumnarValue::Array(x) = x.evaluate(&batch)? else {
unreachable!()
};
assert_eq!(3, x.len());
Ok(())
}

#[test]
fn test_create_struct_from_dict_encoded_string() -> Result<()> {
let keys = Int32Array::from(vec![0, 1, 2]);
let values = StringArray::from(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
let dict = DictionaryArray::try_new(keys, Arc::new(values))?;
let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
let schema = Schema::new(vec![Field::new("a", data_type, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?;
let data_type =
DataType::Struct(Fields::from(vec![Field::new("a", DataType::Utf8, false)]));
let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], data_type);
let ColumnarValue::Array(x) = x.evaluate(&batch)? else {
unreachable!()
};
assert_eq!(3, x.len());
Ok(())
}
}
16 changes: 16 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1887,4 +1887,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("struct and named_struct with dictionary") {
Seq(true, false).foreach { dictionaryEnabled =>
withParquetTable(
(0 until 100).map(i =>
(
i,
if (i % 2 == 0) { "even" }
else { "odd" })),
"tbl",
withDictionary = dictionaryEnabled) {
checkSparkAnswerAndOperator("SELECT struct(_1, _2) FROM tbl")
checkSparkAnswerAndOperator("SELECT named_struct('a', _1, 'b', _2) FROM tbl")
}
}
}
}

0 comments on commit 698c1b2

Please sign in to comment.