diff --git a/native-sql-engine/core/src/test/scala/com/intel/oap/execution/PayloadSuite.scala b/native-sql-engine/core/src/test/scala/com/intel/oap/execution/PayloadSuite.scala index 87ce814ba..f878c877a 100644 --- a/native-sql-engine/core/src/test/scala/com/intel/oap/execution/PayloadSuite.scala +++ b/native-sql-engine/core/src/test/scala/com/intel/oap/execution/PayloadSuite.scala @@ -21,12 +21,14 @@ import java.nio.file.Files import com.intel.oap.tpc.util.TPCRunner import org.apache.log4j.{Level, LogManager} + import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions.{col, expr} import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.PackageAccessor class PayloadSuite extends QueryTest with SharedSparkSession { @@ -75,20 +77,29 @@ class PayloadSuite extends QueryTest with SharedSparkSession { val lfile = Files.createTempFile("", ".parquet").toFile lfile.deleteOnExit() lPath = lfile.getAbsolutePath - spark.range(2).select(col("id"), expr("1").as("kind"), - expr("1").as("key"), - expr("array(1, 2)").as("arr_field"), - expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"), - expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"), - expr("array(map(1, 2), map(3,4))").as("arr_map_field"), - expr("struct(1, 2)").as("struct_field"), - expr("struct(1, struct(1, 2))").as("struct_struct_field"), - expr("struct(1, array(1, 2))").as("struct_array_field"), - expr("map(1, 2)").as("map_field"), - expr("map(1, map(3,4))").as("map_map_field"), - expr("map(1, array(1, 2))").as("map_arr_field"), - expr("map(struct(1, 2), 2)").as("map_struct_field")) - .coalesce(1) + val dfl = spark + .range(2) + .select( + col("id"), + expr("1").as("kind"), + expr("1").as("key"), + expr("array(1, 2)").as("arr_field"), + expr("array(\"hello\", \"world\")").as("arr_str_field"), + expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"), + expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"), + expr("array(map(1, 2), map(3,4))").as("arr_map_field"), + expr("struct(1, 2)").as("struct_field"), + expr("struct(1, struct(1, 2))").as("struct_struct_field"), + expr("struct(1, array(1, 2))").as("struct_array_field"), + expr("map(1, 2)").as("map_field"), + expr("map(1, map(3,4))").as("map_map_field"), + expr("map(1, array(1, 2))").as("map_arr_field"), + expr("map(struct(1, 2), 2)").as("map_struct_field")) + + // Arrow scan doesn't support converting from non-null nested type to nullable as of now + val dflNullable = dfl.sqlContext.createDataFrame(dfl.rdd, PackageAccessor.asNullable(dfl.schema)) + + dflNullable.coalesce(1) .write .format("parquet") .mode("overwrite") @@ -97,11 +108,19 @@ class PayloadSuite extends QueryTest with SharedSparkSession { val rfile = Files.createTempFile("", ".parquet").toFile rfile.deleteOnExit() rPath = rfile.getAbsolutePath - spark.range(2).select(col("id"), expr("id % 2").as("kind"), - expr("id % 2").as("key"), - expr("array(1, 2)").as("arr_field"), - expr("struct(1, 2)").as("struct_field")) - .coalesce(1) + + val dfr = spark.range(2) + .select( + col("id"), + expr("id % 2").as("kind"), + expr("id % 2").as("key"), + expr("array(1, 2)").as("arr_field"), + expr("struct(1, 2)").as("struct_field")) + + // Arrow scan doesn't support converting from non-null nested type to nullable as of now + val dfrNullable = dfr.sqlContext.createDataFrame(dfr.rdd, PackageAccessor.asNullable(dfr.schema)) + + dfrNullable.coalesce(1) .write .format("parquet") .mode("overwrite") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleSQLSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleSQLSuite.scala index f4ce2120a..371697e22 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleSQLSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleSQLSuite.scala @@ -73,20 +73,28 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession { val lfile = Files.createTempFile("", ".parquet").toFile lfile.deleteOnExit() lPath = lfile.getAbsolutePath - spark.range(2).select(col("id"), expr("1").as("kind"), - expr("array(1, 2)").as("arr_field"), - expr("array(\"hello\", \"world\")").as("arr_str_field"), - expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"), - expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"), - expr("array(map(1, 2), map(3,4))").as("arr_map_field"), - expr("struct(1, 2)").as("struct_field"), - expr("struct(1, struct(1, 2))").as("struct_struct_field"), - expr("struct(1, array(1, 2))").as("struct_array_field"), - expr("map(1, 2)").as("map_field"), - expr("map(1, map(3,4))").as("map_map_field"), - expr("map(1, array(1, 2))").as("map_arr_field"), - expr("map(struct(1, 2), 2)").as("map_struct_field")) - .coalesce(1) + val dfl = spark + .range(2) + .select( + col("id"), + expr("1").as("kind"), + expr("array(1, 2)").as("arr_field"), + expr("array(\"hello\", \"world\")").as("arr_str_field"), + expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"), + expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"), + expr("array(map(1, 2), map(3,4))").as("arr_map_field"), + expr("struct(1, 2)").as("struct_field"), + expr("struct(1, struct(1, 2))").as("struct_struct_field"), + expr("struct(1, array(1, 2))").as("struct_array_field"), + expr("map(1, 2)").as("map_field"), + expr("map(1, map(3,4))").as("map_map_field"), + expr("map(1, array(1, 2))").as("map_arr_field"), + expr("map(struct(1, 2), 2)").as("map_struct_field")) + + // Arrow scan doesn't support converting from non-null nested type to nullable as of now + val dflNullable = dfl.sqlContext.createDataFrame(dfl.rdd, dfl.schema.asNullable) + + dflNullable.coalesce(1) .write .format("parquet") .mode("overwrite") @@ -95,10 +103,18 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession { val rfile = Files.createTempFile("", ".parquet").toFile rfile.deleteOnExit() rPath = rfile.getAbsolutePath - spark.range(2).select(col("id"), expr("id % 2").as("kind"), - expr("array(1, 2)").as("arr_field"), - expr("struct(1, 2)").as("struct_field")) - .coalesce(1) + + val dfr = spark.range(2) + .select( + col("id"), + expr("id % 2").as("kind"), + expr("array(1, 2)").as("arr_field"), + expr("struct(1, 2)").as("struct_field")) + + // Arrow scan doesn't support converting from non-null nested type to nullable as of now + val dfrNullable = dfr.sqlContext.createDataFrame(dfr.rdd, dfr.schema.asNullable) + + dfrNullable.coalesce(1) .write .format("parquet") .mode("overwrite") diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/util/PackageAccessor.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/util/PackageAccessor.scala new file mode 100644 index 000000000..0aa981552 --- /dev/null +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/util/PackageAccessor.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.sql.types.StructType + +object PackageAccessor { + def asNullable(schema: StructType): StructType = { + schema.asNullable + } +}