From 655f62e3eaa9a282141879e8c131cb26e2d978e9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 16 Jun 2016 16:46:27 +0900 Subject: [PATCH 01/22] optimize to read primitive array elements in Dataset --- .../spark/sql/catalyst/expressions/Cast.scala | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 70fff5195625..90cedd05aa0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -835,8 +835,36 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val j = ctx.freshName("j") val values = ctx.freshName("values") + val isPrimitiveFrom = ctx.isPrimitiveType(fromType) + val isPrimitiveTo = ctx.isPrimitiveType(toType) + (c, evPrim, evNull) => - s""" + if (isPrimitiveFrom && isPrimitiveTo) { + // ensure no null in input and output arrays + val javaDTFrom = ctx.javaType(fromType) + val javaDTTo = ctx.javaType(toType) + if (javaDTFrom == javaDTTo) { + val boxedTypeTo = ctx.primitiveTypeName(javaDTTo) + s""" + final ${javaDTTo}[] $values = $c.to${boxedTypeTo}Array(); + $evPrim = new $arrayClass($values); + """ + } else { + s""" + final int $size = $c.numElements(); + final ${javaDTTo}[] $values = new ${javaDTTo}[$c.numElements()]; + for (int $j = 0; $j < $size; $j ++) { + ${ctx.javaType(fromType)} $fromElementPrim = + ${ctx.getValue(c, fromType, j)}; + ${castCode(ctx, fromElementPrim, + "false", toElementPrim, toElementNull, toType, elementCast)} + $values[$j] = $toElementPrim; + } + $evPrim = new $arrayClass($values); + """ + } + } else { + s""" final int $size = $c.numElements(); final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { @@ -856,7 +884,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } $evPrim = new $arrayClass($values); - """ + """ + } } private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = { From 50ed1e2aebda6bfa1eaf86da1f8ce6e408f4fc61 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 2 Jul 2016 03:30:30 +0900 Subject: [PATCH 02/22] check whether ArrayType.isContainsNull is false --- .../spark/sql/catalyst/expressions/Cast.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 90cedd05aa0a..55a59547c444 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -477,7 +477,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DoubleType => castToDoubleCode(from) case array: ArrayType => - castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) + castArrayCode(from.asInstanceOf[ArrayType], array, ctx) case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) case udt: UserDefinedType[_] @@ -823,8 +823,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s"$evPrim = (double) $c;" } - private[this] def castArrayCode( - fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { + private[this] def castArrayCode(fromArrayType: ArrayType, toArrayType: ArrayType, + ctx: CodegenContext): CastFunction = { + val fromType = fromArrayType.elementType + val toType = toArrayType.elementType val elementCast = nullSafeCastFunction(fromType, toType, ctx) val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") @@ -836,11 +838,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val values = ctx.freshName("values") val isPrimitiveFrom = ctx.isPrimitiveType(fromType) + val ArrayType(_, containsNullFrom) = fromArrayType val isPrimitiveTo = ctx.isPrimitiveType(toType) + val ArrayType(_, containsNullTo) = toArrayType + print(s"containsNullFrom=$containsNullFrom, containsNullTo=$containsNullTo\n") (c, evPrim, evNull) => - if (isPrimitiveFrom && isPrimitiveTo) { - // ensure no null in input and output arrays + if (isPrimitiveFrom && !containsNullFrom && isPrimitiveTo && !containsNullTo) { + // ensure no null in input and output primitive arrays val javaDTFrom = ctx.javaType(fromType) val javaDTTo = ctx.javaType(toType) if (javaDTFrom == javaDTTo) { @@ -889,8 +894,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = { - val keysCast = castArrayCode(from.keyType, to.keyType, ctx) - val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) + val keysCast = castArrayCode(ArrayType(from.keyType, false), ArrayType(to.keyType, false), ctx) + val valuesCast = + castArrayCode(ArrayType(from.valueType, true), ArrayType(to.valueType, true), ctx) val mapClass = classOf[ArrayBasedMapData].getName From 2de5e1ebfaeb5ea7912ccb4d16f797f4d134c8cd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 2 Jul 2016 03:31:32 +0900 Subject: [PATCH 03/22] pass precise isNullable information --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e559f235c5a3..d3537a3b8719 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2040,7 +2040,7 @@ class Analyzer( fail(child, DateType, walkedTypePath) case (StringType, to: NumericType) => fail(child, to, walkedTypePath) - case _ => Cast(child, dataType.asNullable) + case _ => Cast(child, dataType) } } } From 6026a21270fc360ee867f25c51a295e13d66edb0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 2 Jul 2016 03:34:05 +0900 Subject: [PATCH 04/22] remove debug print --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 55a59547c444..22c90dfd9a58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -841,7 +841,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val ArrayType(_, containsNullFrom) = fromArrayType val isPrimitiveTo = ctx.isPrimitiveType(toType) val ArrayType(_, containsNullTo) = toArrayType - print(s"containsNullFrom=$containsNullFrom, containsNullTo=$containsNullTo\n") (c, evPrim, evNull) => if (isPrimitiveFrom && !containsNullFrom && isPrimitiveTo && !containsNullTo) { From f3df6c512f5c787e1dad332552285c23f7fca015 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 3 Jul 2016 14:44:26 +0900 Subject: [PATCH 05/22] fix test failures by precisely passing nullable information to cast --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d3537a3b8719..dca3ad20f63b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2040,7 +2040,9 @@ class Analyzer( fail(child, DateType, walkedTypePath) case (StringType, to: NumericType) => fail(child, to, walkedTypePath) - case _ => Cast(child, dataType) + case (from: ArrayType, to: ArrayType) if !from.containsNull => + Cast(child, dataType) + case _ => Cast(child, dataType.asNullable) } } } From f4c84484400315274bb99b12cf889164f460575b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 10 Jul 2016 17:26:06 +0900 Subject: [PATCH 06/22] add benchmark program --- .../benchmark/PrimitiveArrayBenchmark.scala | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala new file mode 100644 index 000000000000..3ab2a066ae47 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.util.{Arrays, Comparator, Random} + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.array.LongArray +import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.Sorter +import org.apache.spark.util.collection.unsafe.sort._ + +/** + * Benchmark to measure performance for accessing primitive arrays + * To run this: + * 1. Replace ignore(...) with test(...) + * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ +class PrimitiveArrayBenchmark extends BenchmarkBase { + + test("Read array in Dataset") { + import sparkSession.implicits._ + + val iters = 5 + val n = 1024 * 1024 + val rows = 15 + + val benchmark = new Benchmark("Read primnitive array", n) + + val rand = new Random(511) + val intDS = sparkSession.sparkContext.parallelize(0 until rows, 1) + .map(i => Array.tabulate(n)(i => i)).toDS() + intDS.count() // force to create ds + val lastElement = n - 1 + val randElement = rand.nextInt(lastElement) + + benchmark.addCase(s"Read int array in Dataset", numIters = iters)(iter => { + val idx0 = randElement + val idx1 = lastElement + intDS.map(a => a(0) + a(idx0) + a(idx1)).collect + }) + + val doubleDS = sparkSession.sparkContext.parallelize(0 until rows, 1) + .map(i => Array.tabulate(n)(i => i.toDouble)).toDS() + doubleDS.count() // force to create ds + + benchmark.addCase(s"Read double array in Dataset", numIters = iters)(iter => { + val idx0 = randElement + val idx1 = lastElement + doubleDS.map(a => a(0) + a(idx0) + a(idx1)).collect + }) + + benchmark.run() + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.4 + Intel(R) Core(TM) i5-5257U CPU @ 2.70GHz + + Read primnitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Read int array in Dataset 400 / 492 2.6 381.5 1.0X + Read double array in Dataset 788 / 870 1.3 751.4 0.5X + */ + } +} From f8fe24eb3864f100eb4d7e7ee3c00177f3024af0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 10 Jul 2016 23:10:18 +0900 Subject: [PATCH 07/22] fix test failures --- .../spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index 3ab2a066ae47..baa6453b7489 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -38,7 +38,7 @@ import org.apache.spark.util.collection.unsafe.sort._ */ class PrimitiveArrayBenchmark extends BenchmarkBase { - test("Read array in Dataset") { + ignore("Read array in Dataset") { import sparkSession.implicits._ val iters = 5 From 76225028769e3fdbcece8cd3c9e906d672985de4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 11 Jul 2016 12:21:27 +0900 Subject: [PATCH 08/22] revert changes for improvement of cast code generation --- .../spark/sql/catalyst/expressions/Cast.scala | 48 +++---------------- 1 file changed, 7 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 22c90dfd9a58..70fff5195625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -477,7 +477,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DoubleType => castToDoubleCode(from) case array: ArrayType => - castArrayCode(from.asInstanceOf[ArrayType], array, ctx) + castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) case udt: UserDefinedType[_] @@ -823,10 +823,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s"$evPrim = (double) $c;" } - private[this] def castArrayCode(fromArrayType: ArrayType, toArrayType: ArrayType, - ctx: CodegenContext): CastFunction = { - val fromType = fromArrayType.elementType - val toType = toArrayType.elementType + private[this] def castArrayCode( + fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") @@ -837,38 +835,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val j = ctx.freshName("j") val values = ctx.freshName("values") - val isPrimitiveFrom = ctx.isPrimitiveType(fromType) - val ArrayType(_, containsNullFrom) = fromArrayType - val isPrimitiveTo = ctx.isPrimitiveType(toType) - val ArrayType(_, containsNullTo) = toArrayType - (c, evPrim, evNull) => - if (isPrimitiveFrom && !containsNullFrom && isPrimitiveTo && !containsNullTo) { - // ensure no null in input and output primitive arrays - val javaDTFrom = ctx.javaType(fromType) - val javaDTTo = ctx.javaType(toType) - if (javaDTFrom == javaDTTo) { - val boxedTypeTo = ctx.primitiveTypeName(javaDTTo) - s""" - final ${javaDTTo}[] $values = $c.to${boxedTypeTo}Array(); - $evPrim = new $arrayClass($values); - """ - } else { - s""" - final int $size = $c.numElements(); - final ${javaDTTo}[] $values = new ${javaDTTo}[$c.numElements()]; - for (int $j = 0; $j < $size; $j ++) { - ${ctx.javaType(fromType)} $fromElementPrim = - ${ctx.getValue(c, fromType, j)}; - ${castCode(ctx, fromElementPrim, - "false", toElementPrim, toElementNull, toType, elementCast)} - $values[$j] = $toElementPrim; - } - $evPrim = new $arrayClass($values); - """ - } - } else { - s""" + s""" final int $size = $c.numElements(); final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { @@ -888,14 +856,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } $evPrim = new $arrayClass($values); - """ - } + """ } private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = { - val keysCast = castArrayCode(ArrayType(from.keyType, false), ArrayType(to.keyType, false), ctx) - val valuesCast = - castArrayCode(ArrayType(from.valueType, true), ArrayType(to.valueType, true), ctx) + val keysCast = castArrayCode(from.keyType, to.keyType, ctx) + val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) val mapClass = classOf[ArrayBasedMapData].getName From 90b6542519c4bab56cd230bd54244c6821d68255 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 11 Jul 2016 12:21:46 +0900 Subject: [PATCH 09/22] fix typo --- .../sql/execution/benchmark/PrimitiveArrayBenchmark.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index baa6453b7489..6916959cfc93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -45,7 +45,7 @@ class PrimitiveArrayBenchmark extends BenchmarkBase { val n = 1024 * 1024 val rows = 15 - val benchmark = new Benchmark("Read primnitive array", n) + val benchmark = new Benchmark("Read primitive array", n) val rand = new Random(511) val intDS = sparkSession.sparkContext.parallelize(0 until rows, 1) @@ -75,7 +75,7 @@ class PrimitiveArrayBenchmark extends BenchmarkBase { Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.4 Intel(R) Core(TM) i5-5257U CPU @ 2.70GHz - Read primnitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Read primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Read int array in Dataset 400 / 492 2.6 381.5 1.0X Read double array in Dataset 788 / 870 1.3 751.4 0.5X From 7615342af0e9011848c7461e12726ad56538f2e9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 11 Jul 2016 18:42:33 +0900 Subject: [PATCH 10/22] eliminate unnecessary cast with non-nullable at SimplifyCasts --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dca3ad20f63b..e559f235c5a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2040,8 +2040,6 @@ class Analyzer( fail(child, DateType, walkedTypePath) case (StringType, to: NumericType) => fail(child, to, walkedTypePath) - case (from: ArrayType, to: ArrayType) if !from.containsNull => - Cast(child, dataType) case _ => Cast(child, dataType.asNullable) } } From 4a198b29be77f703ac7cb6d8cfc463029e219bbd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 11 Jul 2016 19:17:13 +0900 Subject: [PATCH 11/22] simplify case statements --- .../apache/spark/sql/catalyst/optimizer/expressions.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 74dfd10189d8..82ab111aa225 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,6 +475,12 @@ case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { object SimplifyCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Cast(e, dataType) if e.dataType == dataType => e + case c @ Cast(e, dataType) => (e.dataType, dataType) match { + case (ArrayType(from, false), ArrayType(to, true)) if from == to => e + case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true)) + if fromKey == toKey && fromValue == toValue => e + case _ => c + } } } From 6c7904ee8fd94e2cbc6fbbf11e01f8f8e6e63ea2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jul 2016 01:34:14 +0900 Subject: [PATCH 12/22] add test cases --- .../optimizer/SimplifyCastsSuite.scala | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala new file mode 100644 index 000000000000..0ee25e88fe58 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class SimplifyCastsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil + } + + test("non-nullable to non-nullable array cast") { + val input = LocalRelation('a.array(ArrayType(IntegerType, false))) + val array_intPrimitive = Literal.create( + Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, false)) + val plan = input.select(array_intPrimitive + .cast(ArrayType(IntegerType, false)).as('a)).analyze + val optimized = Optimize.execute(plan) + val expected = input.select(array_intPrimitive.as('a)).analyze + comparePlans(optimized, expected) + } + + test("non-nullable to nullable array cast") { + val input = LocalRelation('a.array(ArrayType(IntegerType, false))) + val array_intPrimitive = Literal.create( + Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, false)) + val plan = input.select(array_intPrimitive + .cast(ArrayType(IntegerType, true)).as('a)).analyze + val optimized = Optimize.execute(plan) + val expected = input.select(array_intPrimitive.as('a)).analyze + comparePlans(optimized, expected) + } + + test("non-nullable to non-nullable map cast") { + val input = LocalRelation('m.array(MapType(StringType, StringType, false))) + val map_notNull = Literal.create( + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, false)) + val plan = input.select(map_notNull + .cast(MapType(StringType, StringType, false)).as('m)).analyze + val optimized = Optimize.execute(plan) + val expected = input.select(map_notNull.as('m)).analyze + comparePlans(optimized, expected) + } + + test("non-nullable to nullable map cast") { + val input = LocalRelation('m.array(MapType(StringType, StringType, false))) + val map_notNull = Literal.create( + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, false)) + val plan = input.select(map_notNull + .cast(MapType(StringType, StringType, true)).as('m)).analyze + val optimized = Optimize.execute(plan) + val expected = input.select(map_notNull.as('m)).analyze + comparePlans(optimized, expected) + } +} \ No newline at end of file From 3fbe5e96662168e5b4355e364a3fb56d3c17f2d3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jul 2016 02:15:06 +0900 Subject: [PATCH 13/22] fix scala style error --- .../spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 0ee25e88fe58..3e7667bb74a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -74,4 +74,5 @@ class SimplifyCastsSuite extends PlanTest { val expected = input.select(map_notNull.as('m)).analyze comparePlans(optimized, expected) } -} \ No newline at end of file +} + From ee29c1bbc7d01c9aaa955d7b4fb82f983975cc3a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jul 2016 15:49:09 +0900 Subject: [PATCH 14/22] drop benchmark program --- .../benchmark/PrimitiveArrayBenchmark.scala | 84 ------------------- 1 file changed, 84 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala deleted file mode 100644 index 6916959cfc93..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.benchmark - -import java.util.{Arrays, Comparator, Random} - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock -import org.apache.spark.util.Benchmark -import org.apache.spark.util.collection.Sorter -import org.apache.spark.util.collection.unsafe.sort._ - -/** - * Benchmark to measure performance for accessing primitive arrays - * To run this: - * 1. Replace ignore(...) with test(...) - * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayBenchmark" - * - * Benchmarks in this file are skipped in normal builds. - */ -class PrimitiveArrayBenchmark extends BenchmarkBase { - - ignore("Read array in Dataset") { - import sparkSession.implicits._ - - val iters = 5 - val n = 1024 * 1024 - val rows = 15 - - val benchmark = new Benchmark("Read primitive array", n) - - val rand = new Random(511) - val intDS = sparkSession.sparkContext.parallelize(0 until rows, 1) - .map(i => Array.tabulate(n)(i => i)).toDS() - intDS.count() // force to create ds - val lastElement = n - 1 - val randElement = rand.nextInt(lastElement) - - benchmark.addCase(s"Read int array in Dataset", numIters = iters)(iter => { - val idx0 = randElement - val idx1 = lastElement - intDS.map(a => a(0) + a(idx0) + a(idx1)).collect - }) - - val doubleDS = sparkSession.sparkContext.parallelize(0 until rows, 1) - .map(i => Array.tabulate(n)(i => i.toDouble)).toDS() - doubleDS.count() // force to create ds - - benchmark.addCase(s"Read double array in Dataset", numIters = iters)(iter => { - val idx0 = randElement - val idx1 = lastElement - doubleDS.map(a => a(0) + a(idx0) + a(idx1)).collect - }) - - benchmark.run() - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.4 - Intel(R) Core(TM) i5-5257U CPU @ 2.70GHz - - Read primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Read int array in Dataset 400 / 492 2.6 381.5 1.0X - Read double array in Dataset 788 / 870 1.3 751.4 0.5X - */ - } -} From abdb5299606e4b8cb52095bac56b62671ad99ac0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jul 2016 15:50:09 +0900 Subject: [PATCH 15/22] add more test cases --- .../optimizer/SimplifyCastsSuite.scala | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 3e7667bb74a2..414833a2041c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -53,6 +53,27 @@ class SimplifyCastsSuite extends PlanTest { comparePlans(optimized, expected) } + test("nullable to non-nullable array cast") { + val input = LocalRelation('a.array(ArrayType(IntegerType, true))) + val array_intNull = Literal.create( + Seq(1, 2, null, 4, 5), ArrayType(IntegerType, true)) + val plan = input.select(array_intNull + .cast(ArrayType(IntegerType, false)).as('a)).analyze + val optimized = Optimize.execute(plan) + assert(optimized.resolved === false) + } + + test("nullable to nullable array cast") { + val input = LocalRelation('a.array(ArrayType(IntegerType, true))) + val array_intNull = Literal.create( + Seq(1, 2, null, 4, 5), ArrayType(IntegerType, true)) + val plan = input.select(array_intNull + .cast(ArrayType(IntegerType, true)).as('a)).analyze + val optimized = Optimize.execute(plan) + val expected = input.select(array_intNull.as('a)).analyze + comparePlans(optimized, expected) + } + test("non-nullable to non-nullable map cast") { val input = LocalRelation('m.array(MapType(StringType, StringType, false))) val map_notNull = Literal.create( @@ -74,5 +95,26 @@ class SimplifyCastsSuite extends PlanTest { val expected = input.select(map_notNull.as('m)).analyze comparePlans(optimized, expected) } + + test("nullable to non-nullable map cast") { + val input = LocalRelation('m.array(MapType(StringType, StringType, true))) + val map_Null = Literal.create( + Map("a" -> "123", "b" -> null, "c" -> "f"), MapType(StringType, StringType, true)) + val plan = input.select(map_Null + .cast(MapType(StringType, StringType, false)).as('m)).analyze + val optimized = Optimize.execute(plan) + assert(optimized.resolved === false) + } + + test("nullable to nullable map cast") { + val input = LocalRelation('m.array(MapType(StringType, StringType, true))) + val map_Null = Literal.create( + Map("a" -> "123", "b" -> null, "c" -> "f"), MapType(StringType, StringType, true)) + val plan = input.select(map_Null + .cast(MapType(StringType, StringType, true)).as('m)).analyze + val optimized = Optimize.execute(plan) + val expected = input.select(map_Null.as('m)).analyze + comparePlans(optimized, expected) + } } From 52e8bfdb7e7d03504a459743391acdbfb4eec9bb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jul 2016 18:29:03 +0900 Subject: [PATCH 16/22] update test cases --- .../spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 414833a2041c..98f9d8273d00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -60,7 +60,7 @@ class SimplifyCastsSuite extends PlanTest { val plan = input.select(array_intNull .cast(ArrayType(IntegerType, false)).as('a)).analyze val optimized = Optimize.execute(plan) - assert(optimized.resolved === false) + comparePlans(optimized, plan) } test("nullable to nullable array cast") { @@ -103,7 +103,7 @@ class SimplifyCastsSuite extends PlanTest { val plan = input.select(map_Null .cast(MapType(StringType, StringType, false)).as('m)).analyze val optimized = Optimize.execute(plan) - assert(optimized.resolved === false) + comparePlans(optimized, plan) } test("nullable to nullable map cast") { From ab0bed18fc1b4c208a3426f4afd2af2f3bd89a6e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jul 2016 19:43:33 +0900 Subject: [PATCH 17/22] update test cases --- .../catalyst/optimizer/SimplifyCastsSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 98f9d8273d00..1b695f2e6ce0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -32,7 +32,7 @@ class SimplifyCastsSuite extends PlanTest { } test("non-nullable to non-nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, false))) + val input = LocalRelation('a.array(ArrayType(IntegerType))) val array_intPrimitive = Literal.create( Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, false)) val plan = input.select(array_intPrimitive @@ -43,7 +43,7 @@ class SimplifyCastsSuite extends PlanTest { } test("non-nullable to nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, false))) + val input = LocalRelation('a.array(ArrayType(IntegerType))) val array_intPrimitive = Literal.create( Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, false)) val plan = input.select(array_intPrimitive @@ -54,7 +54,7 @@ class SimplifyCastsSuite extends PlanTest { } test("nullable to non-nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, true))) + val input = LocalRelation('a.array(ArrayType(IntegerType))) val array_intNull = Literal.create( Seq(1, 2, null, 4, 5), ArrayType(IntegerType, true)) val plan = input.select(array_intNull @@ -64,7 +64,7 @@ class SimplifyCastsSuite extends PlanTest { } test("nullable to nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, true))) + val input = LocalRelation('a.array(ArrayType(IntegerType))) val array_intNull = Literal.create( Seq(1, 2, null, 4, 5), ArrayType(IntegerType, true)) val plan = input.select(array_intNull @@ -75,7 +75,7 @@ class SimplifyCastsSuite extends PlanTest { } test("non-nullable to non-nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType, false))) + val input = LocalRelation('m.array(MapType(StringType, StringType))) val map_notNull = Literal.create( Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, false)) val plan = input.select(map_notNull @@ -86,7 +86,7 @@ class SimplifyCastsSuite extends PlanTest { } test("non-nullable to nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType, false))) + val input = LocalRelation('m.array(MapType(StringType, StringType))) val map_notNull = Literal.create( Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, false)) val plan = input.select(map_notNull @@ -97,7 +97,7 @@ class SimplifyCastsSuite extends PlanTest { } test("nullable to non-nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType, true))) + val input = LocalRelation('m.array(MapType(StringType, StringType))) val map_Null = Literal.create( Map("a" -> "123", "b" -> null, "c" -> "f"), MapType(StringType, StringType, true)) val plan = input.select(map_Null @@ -107,7 +107,7 @@ class SimplifyCastsSuite extends PlanTest { } test("nullable to nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType, true))) + val input = LocalRelation('m.array(MapType(StringType, StringType))) val map_Null = Literal.create( Map("a" -> "123", "b" -> null, "c" -> "f"), MapType(StringType, StringType, true)) val plan = input.select(map_Null From 02072049446e6a7fd415eff4f952a88429a459e3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 13 Jul 2016 01:14:18 +0900 Subject: [PATCH 18/22] update test cases --- .../optimizer/SimplifyCastsSuite.scala | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 1b695f2e6ce0..eb0aef5b525e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -31,10 +31,12 @@ class SimplifyCastsSuite extends PlanTest { val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil } + def array(arrayType: ArrayType): AttributeReference = + AttributeReference("a", arrayType)() + test("non-nullable to non-nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intPrimitive = Literal.create( - Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, false)) + val array_intPrimitive = array(ArrayType(IntegerType, false)) val plan = input.select(array_intPrimitive .cast(ArrayType(IntegerType, false)).as('a)).analyze val optimized = Optimize.execute(plan) @@ -44,8 +46,7 @@ class SimplifyCastsSuite extends PlanTest { test("non-nullable to nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intPrimitive = Literal.create( - Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, false)) + val array_intPrimitive = array(ArrayType(IntegerType, false)) val plan = input.select(array_intPrimitive .cast(ArrayType(IntegerType, true)).as('a)).analyze val optimized = Optimize.execute(plan) @@ -55,8 +56,7 @@ class SimplifyCastsSuite extends PlanTest { test("nullable to non-nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intNull = Literal.create( - Seq(1, 2, null, 4, 5), ArrayType(IntegerType, true)) + val array_intNull = array(ArrayType(IntegerType, true)) val plan = input.select(array_intNull .cast(ArrayType(IntegerType, false)).as('a)).analyze val optimized = Optimize.execute(plan) @@ -65,8 +65,7 @@ class SimplifyCastsSuite extends PlanTest { test("nullable to nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intNull = Literal.create( - Seq(1, 2, null, 4, 5), ArrayType(IntegerType, true)) + val array_intNull = array(ArrayType(IntegerType, true)) val plan = input.select(array_intNull .cast(ArrayType(IntegerType, true)).as('a)).analyze val optimized = Optimize.execute(plan) @@ -74,10 +73,15 @@ class SimplifyCastsSuite extends PlanTest { comparePlans(optimized, expected) } + def map(keyType: DataType, valueType: DataType, nullable: Boolean): AttributeReference = + map(MapType(keyType, valueType, nullable)) + + def map(mapType: MapType): AttributeReference = + AttributeReference("m", mapType)() + test("non-nullable to non-nullable map cast") { val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, false)) + val map_notNull = map(StringType, StringType, false) val plan = input.select(map_notNull .cast(MapType(StringType, StringType, false)).as('m)).analyze val optimized = Optimize.execute(plan) @@ -87,8 +91,7 @@ class SimplifyCastsSuite extends PlanTest { test("non-nullable to nullable map cast") { val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, false)) + val map_notNull = map(StringType, StringType, false) val plan = input.select(map_notNull .cast(MapType(StringType, StringType, true)).as('m)).analyze val optimized = Optimize.execute(plan) @@ -98,8 +101,7 @@ class SimplifyCastsSuite extends PlanTest { test("nullable to non-nullable map cast") { val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_Null = Literal.create( - Map("a" -> "123", "b" -> null, "c" -> "f"), MapType(StringType, StringType, true)) + val map_Null = map(StringType, StringType, true) val plan = input.select(map_Null .cast(MapType(StringType, StringType, false)).as('m)).analyze val optimized = Optimize.execute(plan) @@ -108,8 +110,7 @@ class SimplifyCastsSuite extends PlanTest { test("nullable to nullable map cast") { val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_Null = Literal.create( - Map("a" -> "123", "b" -> null, "c" -> "f"), MapType(StringType, StringType, true)) + val map_Null = map(StringType, StringType, true) val plan = input.select(map_Null .cast(MapType(StringType, StringType, true)).as('m)).analyze val optimized = Optimize.execute(plan) From d936cf0a87843b0904e05ff23f6b4eb0bfd46c77 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 13 Jul 2016 19:32:57 +0900 Subject: [PATCH 19/22] update test cases --- .../spark/sql/catalyst/dsl/package.scala | 3 +++ .../optimizer/SimplifyCastsSuite.scala | 20 +++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 9f54d709a022..8549187a6636 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -242,6 +242,9 @@ package object dsl { def array(dataType: DataType): AttributeReference = AttributeReference(s, ArrayType(dataType), nullable = true)() + def array(arrayType: ArrayType): AttributeReference = + AttributeReference(s, arrayType)() + /** Creates a new AttributeReference of type map */ def map(keyType: DataType, valueType: DataType): AttributeReference = map(MapType(keyType, valueType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index eb0aef5b525e..5bb61a62ebb6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -31,12 +32,9 @@ class SimplifyCastsSuite extends PlanTest { val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil } - def array(arrayType: ArrayType): AttributeReference = - AttributeReference("a", arrayType)() - test("non-nullable to non-nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intPrimitive = array(ArrayType(IntegerType, false)) + val input = LocalRelation('a.array(ArrayType(IntegerType, false))) + val array_intPrimitive = 'a.array(ArrayType(IntegerType, false)) val plan = input.select(array_intPrimitive .cast(ArrayType(IntegerType, false)).as('a)).analyze val optimized = Optimize.execute(plan) @@ -45,8 +43,8 @@ class SimplifyCastsSuite extends PlanTest { } test("non-nullable to nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intPrimitive = array(ArrayType(IntegerType, false)) + val input = LocalRelation('a.array(ArrayType(IntegerType, false))) + val array_intPrimitive = 'a.array(ArrayType(IntegerType, false)) val plan = input.select(array_intPrimitive .cast(ArrayType(IntegerType, true)).as('a)).analyze val optimized = Optimize.execute(plan) @@ -55,8 +53,8 @@ class SimplifyCastsSuite extends PlanTest { } test("nullable to non-nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intNull = array(ArrayType(IntegerType, true)) + val input = LocalRelation('a.array(ArrayType(IntegerType, true))) + val array_intNull = 'a.array(ArrayType(IntegerType, true)) val plan = input.select(array_intNull .cast(ArrayType(IntegerType, false)).as('a)).analyze val optimized = Optimize.execute(plan) @@ -64,8 +62,8 @@ class SimplifyCastsSuite extends PlanTest { } test("nullable to nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType))) - val array_intNull = array(ArrayType(IntegerType, true)) + val input = LocalRelation('a.array(ArrayType(IntegerType, true))) + val array_intNull = 'a.array(ArrayType(IntegerType, true)) val plan = input.select(array_intNull .cast(ArrayType(IntegerType, true)).as('a)).analyze val optimized = Optimize.execute(plan) From bff189a3e305b36739185b83e75d8abe0423f9db Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 19 Jul 2016 12:24:30 +0900 Subject: [PATCH 20/22] update test cases --- .../optimizer/SimplifyCastsSuite.scala | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 5bb61a62ebb6..0126c9dc5d5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -34,40 +34,33 @@ class SimplifyCastsSuite extends PlanTest { test("non-nullable to non-nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType, false))) - val array_intPrimitive = 'a.array(ArrayType(IntegerType, false)) - val plan = input.select(array_intPrimitive - .cast(ArrayType(IntegerType, false)).as('a)).analyze + val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select(array_intPrimitive.as('a)).analyze + val expected = input.select('a.as("casted")).analyze comparePlans(optimized, expected) } test("non-nullable to nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType, false))) val array_intPrimitive = 'a.array(ArrayType(IntegerType, false)) - val plan = input.select(array_intPrimitive - .cast(ArrayType(IntegerType, true)).as('a)).analyze + val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select(array_intPrimitive.as('a)).analyze + val expected = input.select('a.as("casted")).analyze comparePlans(optimized, expected) } test("nullable to non-nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType, true))) - val array_intNull = 'a.array(ArrayType(IntegerType, true)) - val plan = input.select(array_intNull - .cast(ArrayType(IntegerType, false)).as('a)).analyze + val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze val optimized = Optimize.execute(plan) comparePlans(optimized, plan) } test("nullable to nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType, true))) - val array_intNull = 'a.array(ArrayType(IntegerType, true)) - val plan = input.select(array_intNull - .cast(ArrayType(IntegerType, true)).as('a)).analyze + val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select(array_intNull.as('a)).analyze + val expected = input.select('a.as("casted")).analyze comparePlans(optimized, expected) } From c8f87a1a22dfa9c05aa2f514a7af43c046188b38 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 19 Jul 2016 23:12:08 +0900 Subject: [PATCH 21/22] update test cases --- .../optimizer/SimplifyCastsSuite.scala | 41 +++++++------------ 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 0126c9dc5d5d..1e7c49f49456 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -42,7 +42,6 @@ class SimplifyCastsSuite extends PlanTest { test("non-nullable to nullable array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType, false))) - val array_intPrimitive = 'a.array(ArrayType(IntegerType, false)) val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze val optimized = Optimize.execute(plan) val expected = input.select('a.as("casted")).analyze @@ -64,48 +63,38 @@ class SimplifyCastsSuite extends PlanTest { comparePlans(optimized, expected) } - def map(keyType: DataType, valueType: DataType, nullable: Boolean): AttributeReference = - map(MapType(keyType, valueType, nullable)) - - def map(mapType: MapType): AttributeReference = - AttributeReference("m", mapType)() - test("non-nullable to non-nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_notNull = map(StringType, StringType, false) - val plan = input.select(map_notNull - .cast(MapType(StringType, StringType, false)).as('m)).analyze + val input = LocalRelation('m.map(MapType(StringType, StringType, false))) + val plan = input.select('m.cast(MapType(StringType, StringType, false)) + .as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select(map_notNull.as('m)).analyze + val expected = input.select('m.as("casted")).analyze comparePlans(optimized, expected) } test("non-nullable to nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_notNull = map(StringType, StringType, false) - val plan = input.select(map_notNull - .cast(MapType(StringType, StringType, true)).as('m)).analyze + val input = LocalRelation('m.map(MapType(StringType, StringType, false))) + val plan = input.select('m.cast(MapType(StringType, StringType, true)) + .as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select(map_notNull.as('m)).analyze + val expected = input.select('m.as("casted")).analyze comparePlans(optimized, expected) } test("nullable to non-nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_Null = map(StringType, StringType, true) - val plan = input.select(map_Null - .cast(MapType(StringType, StringType, false)).as('m)).analyze + val input = LocalRelation('m.map(MapType(StringType, StringType, true))) + val plan = input.select('m.cast(MapType(StringType, StringType, false)) + .as("casted")).analyze val optimized = Optimize.execute(plan) comparePlans(optimized, plan) } test("nullable to nullable map cast") { - val input = LocalRelation('m.array(MapType(StringType, StringType))) - val map_Null = map(StringType, StringType, true) - val plan = input.select(map_Null - .cast(MapType(StringType, StringType, true)).as('m)).analyze + val input = LocalRelation('m.map(MapType(StringType, StringType, true))) + val plan = input.select('m.cast(MapType(StringType, StringType, true)) + .as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select(map_Null.as('m)).analyze + val expected = input.select('m.as("casted")).analyze comparePlans(optimized, expected) } } From 40ac2bcd2b3a631ca2121ed24137cbb197d1b278 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 30 Aug 2016 22:36:35 +0900 Subject: [PATCH 22/22] addressed comment --- .../optimizer/SimplifyCastsSuite.scala | 42 ++----------------- 1 file changed, 4 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 1e7c49f49456..e84f11272d21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -32,15 +32,7 @@ class SimplifyCastsSuite extends PlanTest { val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil } - test("non-nullable to non-nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, false))) - val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze - val optimized = Optimize.execute(plan) - val expected = input.select('a.as("casted")).analyze - comparePlans(optimized, expected) - } - - test("non-nullable to nullable array cast") { + test("non-nullable element array to nullable element array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType, false))) val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze val optimized = Optimize.execute(plan) @@ -48,31 +40,14 @@ class SimplifyCastsSuite extends PlanTest { comparePlans(optimized, expected) } - test("nullable to non-nullable array cast") { + test("nullable element to non-nullable element array cast") { val input = LocalRelation('a.array(ArrayType(IntegerType, true))) val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze val optimized = Optimize.execute(plan) comparePlans(optimized, plan) } - test("nullable to nullable array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, true))) - val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze - val optimized = Optimize.execute(plan) - val expected = input.select('a.as("casted")).analyze - comparePlans(optimized, expected) - } - - test("non-nullable to non-nullable map cast") { - val input = LocalRelation('m.map(MapType(StringType, StringType, false))) - val plan = input.select('m.cast(MapType(StringType, StringType, false)) - .as("casted")).analyze - val optimized = Optimize.execute(plan) - val expected = input.select('m.as("casted")).analyze - comparePlans(optimized, expected) - } - - test("non-nullable to nullable map cast") { + test("non-nullable value map to nullable value map cast") { val input = LocalRelation('m.map(MapType(StringType, StringType, false))) val plan = input.select('m.cast(MapType(StringType, StringType, true)) .as("casted")).analyze @@ -81,21 +56,12 @@ class SimplifyCastsSuite extends PlanTest { comparePlans(optimized, expected) } - test("nullable to non-nullable map cast") { + test("nullable value map to non-nullable value map cast") { val input = LocalRelation('m.map(MapType(StringType, StringType, true))) val plan = input.select('m.cast(MapType(StringType, StringType, false)) .as("casted")).analyze val optimized = Optimize.execute(plan) comparePlans(optimized, plan) } - - test("nullable to nullable map cast") { - val input = LocalRelation('m.map(MapType(StringType, StringType, true))) - val plan = input.select('m.cast(MapType(StringType, StringType, true)) - .as("casted")).analyze - val optimized = Optimize.execute(plan) - val expected = input.select('m.as("casted")).analyze - comparePlans(optimized, expected) - } }