From 170e7de8823a5e1840d6dd96334c370bac0417fb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 10 Sep 2016 15:46:56 +0900 Subject: [PATCH 01/13] initial commit --- .../spark/sql/catalyst/ScalaReflection.scala | 9 +++ .../org/apache/spark/sql/DatasetSuite.scala | 12 ++++ .../benchmark/PrimitiveArrayBenchmark.scala | 72 +++++++++++++++++++ 3 files changed, 93 insertions(+) create mode 100755 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 31c6e5def143..a1f4fa34aa60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -441,6 +441,15 @@ object ScalaReflection extends ScalaReflection { val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath MapObjects(serializerFor(_, elementType, newPath), input, dt) + case dt @ (IntegerType | DoubleType) => + // case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | + // FloatType | DoubleType) => + StaticInvoke( + classOf[UnsafeArrayData], + ArrayType(dt, false), + "fromPrimitiveArray", + input :: Nil) + case dt => NewInstance( classOf[GenericArrayData], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 55f04878052a..aa3cf736fbc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -987,6 +987,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(agg, ds.groupBy('id % 2).agg(count('id))) } } + + test("array") { + val arrayInt = Array(1, 2, 3) + val arrayDouble = Array(1.1, 2.2, 3.3) + val arrayString = Array("a", "b", "c") + val dsInt = sparkContext.parallelize(Seq(arrayInt), 1).toDS.map(e => e) + val dsDouble = sparkContext.parallelize(Seq(arrayDouble), 1).toDS.map(e => e) + val dsString = sparkContext.parallelize(Seq(arrayString), 1).toDS.map(e => e) + checkDataset(dsInt, arrayInt) + checkDataset(dsDouble, arrayDouble) + checkDataset(dsString, arrayString) + } } case class Generic[T](id: T, value: Double) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala new file mode 100755 index 000000000000..a6983e82cde2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[PrimitiveArray]] for DataFrame and Dataset program using primitive array + * To run this: + * 1. replace ignore(...) with test(...) + * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayDataBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ +class PrimitiveArrayBenchmark extends BenchmarkBase { + + def readDataFrameArrayElement(iters: Int): Unit = { + import sparkSession.implicits._ + + val count = 1024 * 1024 * 24 + + val sc = sparkSession.sparkContext + val primitiveIntArray = Array.fill[Int](count)(1) + val dfInt = sc.parallelize(Seq(primitiveIntArray), 1).toDF + dfInt.count + val intArray = { i: Int => + var n = 0 + while (n < iters) { + dfInt.selectExpr("value[0]").count + n += 1 + } + } + val primitiveDoubleArray = Array.fill[Double](count)(1.0) + val dfDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDF + dfDouble.count + val doubleArray = { i: Int => + var n = 0 + while (n < iters) { + dfDouble.selectExpr("value[0]").count + n += 1 + } + } + + val benchmark = new Benchmark("Read an array element in DataFrame", count * iters) + benchmark.addCase("Int ")(intArray) + benchmark.addCase("Double")(doubleArray) + benchmark.run + } + + ignore("Read an array element in DataFrame") { + readDataFrameArrayElement(1) + } +} From 327b7ddb7cbf87bda2bc94da10dc59bea2b116fc Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 11 Sep 2016 13:41:27 +0900 Subject: [PATCH 02/13] fix test failures --- .../spark/sql/catalyst/ScalaReflection.scala | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index a1f4fa34aa60..70cd37906884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -444,11 +444,20 @@ object ScalaReflection extends ScalaReflection { case dt @ (IntegerType | DoubleType) => // case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | // FloatType | DoubleType) => - StaticInvoke( - classOf[UnsafeArrayData], - ArrayType(dt, false), - "fromPrimitiveArray", - input :: Nil) + val cls = input.dataType.asInstanceOf[ObjectType].cls + if (cls.isAssignableFrom(classOf[Array[Int]]) || + cls.isAssignableFrom(classOf[Array[Double]])) { + StaticInvoke( + classOf[UnsafeArrayData], + ArrayType(dt, false), + "fromPrimitiveArray", + input :: Nil) + } else { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dt, schemaFor(elementType).nullable)) + } case dt => NewInstance( From e98cb1ebb7e582864c9622cd92ecc8606ded9076 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 5 Oct 2016 21:58:23 +0900 Subject: [PATCH 03/13] rebase add calling primitiveFromArray for other types --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 12 ++++++++---- .../scala/org/apache/spark/sql/DatasetSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 70cd37906884..4ed9f88d0464 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -441,11 +441,15 @@ object ScalaReflection extends ScalaReflection { val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath MapObjects(serializerFor(_, elementType, newPath), input, dt) - case dt @ (IntegerType | DoubleType) => - // case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | - // FloatType | DoubleType) => + case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType) => val cls = input.dataType.asInstanceOf[ObjectType].cls - if (cls.isAssignableFrom(classOf[Array[Int]]) || + if (cls.isAssignableFrom(classOf[Array[Boolean]]) || + cls.isAssignableFrom(classOf[Array[Byte]]) || + cls.isAssignableFrom(classOf[Array[Short]]) || + cls.isAssignableFrom(classOf[Array[Int]]) || + cls.isAssignableFrom(classOf[Array[Long]]) || + cls.isAssignableFrom(classOf[Array[Float]]) || cls.isAssignableFrom(classOf[Array[Double]])) { StaticInvoke( classOf[UnsafeArrayData], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index aa3cf736fbc6..8b0554d5f0ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -989,13 +989,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("array") { + val arrayByte = Array(1.toByte, 2.toByte, 3.toByte) val arrayInt = Array(1, 2, 3) + val arrayLong = Array(1.toLong, 2.toLong, 3.toLong) val arrayDouble = Array(1.1, 2.2, 3.3) val arrayString = Array("a", "b", "c") + val dsByte = sparkContext.parallelize(Seq(arrayByte), 1).toDS.map(e => e) val dsInt = sparkContext.parallelize(Seq(arrayInt), 1).toDS.map(e => e) + val dsLong = sparkContext.parallelize(Seq(arrayLong), 1).toDS.map(e => e) val dsDouble = sparkContext.parallelize(Seq(arrayDouble), 1).toDS.map(e => e) val dsString = sparkContext.parallelize(Seq(arrayString), 1).toDS.map(e => e) + checkDataset(dsByte, arrayByte) checkDataset(dsInt, arrayInt) + checkDataset(dsLong, arrayLong) checkDataset(dsDouble, arrayDouble) checkDataset(dsString, arrayString) } From 3d7ea2cdde112cd2705e0e9bcc9f451892fe69ac Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 7 Oct 2016 01:57:00 +0900 Subject: [PATCH 04/13] update benchmark program --- .../benchmark/PrimitiveArrayBenchmark.scala | 153 +++++++++--------- 1 file changed, 81 insertions(+), 72 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index a6983e82cde2..b630871373b4 100755 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -1,72 +1,81 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.benchmark - -import scala.concurrent.duration._ - -import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.util.Benchmark - -/** - * Benchmark [[PrimitiveArray]] for DataFrame and Dataset program using primitive array - * To run this: - * 1. replace ignore(...) with test(...) - * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayDataBenchmark" - * - * Benchmarks in this file are skipped in normal builds. - */ -class PrimitiveArrayBenchmark extends BenchmarkBase { - - def readDataFrameArrayElement(iters: Int): Unit = { - import sparkSession.implicits._ - - val count = 1024 * 1024 * 24 - - val sc = sparkSession.sparkContext - val primitiveIntArray = Array.fill[Int](count)(1) - val dfInt = sc.parallelize(Seq(primitiveIntArray), 1).toDF - dfInt.count - val intArray = { i: Int => - var n = 0 - while (n < iters) { - dfInt.selectExpr("value[0]").count - n += 1 - } - } - val primitiveDoubleArray = Array.fill[Double](count)(1.0) - val dfDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDF - dfDouble.count - val doubleArray = { i: Int => - var n = 0 - while (n < iters) { - dfDouble.selectExpr("value[0]").count - n += 1 - } - } - - val benchmark = new Benchmark("Read an array element in DataFrame", count * iters) - benchmark.addCase("Int ")(intArray) - benchmark.addCase("Double")(doubleArray) - benchmark.run - } - - ignore("Read an array element in DataFrame") { - readDataFrameArrayElement(1) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[PrimitiveArray]] for DataFrame and Dataset program using primitive array + * To run this: + * 1. replace ignore(...) with test(...) + * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayDataBenchmark" + * + * Benchmarks in this file are skipped in normal builds. 
+ */ +class PrimitiveArrayBenchmark extends BenchmarkBase { + + def writeDatasetArray(iters: Int): Unit = { + import sparkSession.implicits._ + + val count = 1024 * 1024 * 2 + + val sc = sparkSession.sparkContext + val primitiveIntArray = Array.fill[Int](count)(65535) + val dsInt = sc.parallelize(Seq(primitiveIntArray), 1).toDS + dsInt.count + val intArray = { i: Int => + var n = 0 + while (n < iters) { + dsInt.map(e => e).collect + n += 1 + } + } + val primitiveDoubleArray = Array.fill[Double](count)(65535.0) + val dsDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDS + dsDouble.count + val doubleArray = { i: Int => + var n = 0 + var sum = 0L + while (n < iters) { + dsDouble.map(e => e).collect + n += 1 + } + } + + val benchmark = new Benchmark("Write an array in Dataset", count * iters) + benchmark.addCase("Int ")(intArray) + benchmark.addCase("Double")(doubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Write an array in Dataset: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Int 280 / 302 30.0 33.4 1.0X + Double 503 / 519 16.7 60.0 0.6X + */ + } + + ignore("Write an array in Dataset") { + writeDatasetArray(4) + } +} From 5d5ccd6785ddcac1ef4be82aef9e8dcdbcbfab65 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 9 Oct 2016 01:33:53 +0900 Subject: [PATCH 05/13] create UnsafeArrayData from a primitive array in CatalystTypeConverters --- .../sql/catalyst/CatalystTypeConverters.scala | 14 +++++++ .../CatalystTypeConvertersSuite.scala | 37 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 5b9161551a7a..9d32df5c133d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -158,6 +158,20 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): ArrayData = { scalaValue match { + case a: Array[Boolean] => + UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Byte] => + UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Short] => + UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Int] => + UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Long] => + UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Float] => + UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Double] => + UnsafeArrayData.fromPrimitiveArray(a) case a: Array[_] => new GenericArrayData(a.map(elementConverter.toCatalyst)) case s: Seq[_] => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 03bb102c67fe..68f2c31e1b13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ 
class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -61,4 +63,39 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { test("option handling in createToCatalystConverter") { assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) } + + test("primitive array handing") { + val intArray = Array(1, 100, 10000) + val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray) + val intArrayType = ArrayType(IntegerType, false) + assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray) + assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray) + == intUnsafeArray) + + val doubleArray = Array(1.1, 111.1, 11111.1) + val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray) + val doubleArrayType = ArrayType(DoubleType, false) + assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray) + === doubleArray) + assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) + == doubleUnsafeArray) + } + + test("An array with null handing") { + val intArray = Array(1, null, 100, null, 10000) + val intGenericArray = new GenericArrayData(intArray) + val intArrayType = ArrayType(IntegerType, true) + assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray) + === intArray) + assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray) + == intGenericArray) + + val doubleArray = Array(1.1, null, 111.1, null, 11111.1) + val doubleGenericArray = new GenericArrayData(doubleArray) + val doubleArrayType = ArrayType(DoubleType, true) + assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray) + === doubleArray) + assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) + == doubleGenericArray) + } } From edfbce38c65dbb430f3e23aa0edbd0fd889958f0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 9 Oct 2016 02:55:20 +0900 Subject: [PATCH 06/13] create UnsafeArrayData from a primitive array in RowEncoder.serializeFor --- .../sql/catalyst/encoders/RowEncoder.scala | 50 ++++++++++++++----- .../expressions/objects/objects.scala | 8 ++- .../catalyst/encoders/RowEncoderSuite.scala | 26 ++++++++++ 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 2a6fcd03a26b..4ffcc84b6d92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -119,18 +119,33 @@ object RowEncoder { "fromString", inputObject :: Nil) - case t @ ArrayType(et, _) => et match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => - // TODO: validate input type for primitive array. 
- NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = t) - case _ => MapObjects( - element => serializerFor(ValidateExternalType(element, et), et), - inputObject, - ObjectType(classOf[Object])) - } + case t @ ArrayType(et, cn) => + val cls = inputObject.dataType.asInstanceOf[ObjectType].cls + et match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType + if !cn && ( + cls.isAssignableFrom(classOf[Array[Boolean]]) || + cls.isAssignableFrom(classOf[Array[Byte]]) || + cls.isAssignableFrom(classOf[Array[Short]]) || + cls.isAssignableFrom(classOf[Array[Int]]) || + cls.isAssignableFrom(classOf[Array[Long]]) || + cls.isAssignableFrom(classOf[Array[Float]]) || + cls.isAssignableFrom(classOf[Array[Double]])) => + StaticInvoke( + classOf[UnsafeArrayData], + ArrayType(et, false), + "fromPrimitiveArray", + inputObject :: Nil) + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = t) + case _ => MapObjects( + element => serializerFor(ValidateExternalType(element, et), et), + inputObject, + ObjectType(classOf[Object])) + } case t @ MapType(kt, vt, valueNullable) => val keys = @@ -193,6 +208,17 @@ object RowEncoder { // as java.lang.Object. case _: DecimalType => ObjectType(classOf[java.lang.Object]) // In order to support both Array and Seq in external row, we make this as java.lang.Object. + case a @ ArrayType(et, cn) if !cn => + et match { + case BooleanType => ObjectType(classOf[Array[Boolean]]) + case ByteType => ObjectType(classOf[Array[Byte]]) + case ShortType => ObjectType(classOf[Array[Short]]) + case IntegerType => ObjectType(classOf[Array[Int]]) + case LongType => ObjectType(classOf[Array[Long]]) + case FloatType => ObjectType(classOf[Array[Float]]) + case DoubleType => ObjectType(classOf[Array[Double]]) + case _ => ObjectType(classOf[java.lang.Object]) + } case _: ArrayType => ObjectType(classOf[java.lang.Object]) case _ => externalDataTypeFor(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e2ac3c36d9..a33e53a3408f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -993,8 +993,12 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: DecimalType => Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") - case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + case a @ ArrayType(et, cn) => + if (!cn && ctx.isPrimitiveType(et)) { + s"$obj instanceof ${ctx.javaType(et)}[]" + } else { + s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + } case _ => s"$obj instanceof ${ctx.boxedType(dataType)}" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 2e513ea22c15..9224373d935f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ 
-191,6 +191,32 @@ class RowEncoderSuite extends SparkFunSuite { assert(encoder.serializer.head.nullable == false) } + test("RowEncoder should support a primitive array") { + val schema = new StructType() + .add("booleanPrimitiveArray", ArrayType(BooleanType, false)) + .add("bytePrimitiveArray", ArrayType(ByteType, false)) + .add("shortPrimitiveArray", ArrayType(ShortType, false)) + .add("intPrimitiveArray", ArrayType(IntegerType, false)) + .add("longPrimitiveArray", ArrayType(LongType, false)) + .add("floatPrimitiveArray", ArrayType(FloatType, false)) + .add("doublePrimitiveArray", ArrayType(DoubleType, false)) + val encoder = RowEncoder(schema).resolveAndBind() + val input = Seq( + Array(true, false), + Array(1.toByte, 64.toByte, Byte.MaxValue), + Array(1.toShort, 255.toShort, Short.MaxValue), + Array(1, 10000, Int.MaxValue), + Array(1.toLong, 1000000.toLong, Long.MaxValue), + Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue), + Array(11.1111, 123456.7890123, Double.MaxValue) + ) + val row = encoder.toRow(Row.fromSeq(input)) + val convertedBack = encoder.fromRow(row) + input.zipWithIndex.map { case (array, index) => + assert(convertedBack.getSeq(index) === array) + } + } + test("RowEncoder should support array as the external type for ArrayType") { val schema = new StructType() .add("array", ArrayType(IntegerType)) From 8b79146c168b0e0226af256b5d967abc6b0dcafb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 14 Oct 2016 11:52:39 +0900 Subject: [PATCH 07/13] fix typo --- .../spark/sql/catalyst/CatalystTypeConvertersSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 68f2c31e1b13..257d20980356 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -64,7 +64,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) } - test("primitive array handing") { + test("primitive array handling") { val intArray = Array(1, 100, 10000) val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray) val intArrayType = ArrayType(IntegerType, false) @@ -81,7 +81,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { == doubleUnsafeArray) } - test("An array with null handing") { + test("An array with null handling") { val intArray = Array(1, null, 100, null, 10000) val intGenericArray = new GenericArrayData(intArray) val intArrayType = ArrayType(IntegerType, true) From b5473e3b6675c6bbfdd3fe261699f855faf39c94 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 3 Nov 2016 15:55:46 +0900 Subject: [PATCH 08/13] revert changes for ValdateExternType SPARK-18070 requests any ArrayType to accept an element-nullable array and an element-non-nullable array --- .../sql/catalyst/encoders/RowEncoder.scala | 17 +++++------------ .../catalyst/expressions/objects/objects.scala | 6 +----- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 4ffcc84b6d92..df03075207f1 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -122,6 +122,7 @@ object RowEncoder { case t @ ArrayType(et, cn) => val cls = inputObject.dataType.asInstanceOf[ObjectType].cls et match { +/* case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType if !cn && ( cls.isAssignableFrom(classOf[Array[Boolean]]) || @@ -131,11 +132,13 @@ object RowEncoder { cls.isAssignableFrom(classOf[Array[Long]]) || cls.isAssignableFrom(classOf[Array[Float]]) || cls.isAssignableFrom(classOf[Array[Double]])) => + print(s"1@ET: $et, $cn, $cls\n") StaticInvoke( classOf[UnsafeArrayData], ArrayType(et, false), "fromPrimitiveArray", inputObject :: Nil) +*/ case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => NewInstance( classOf[GenericArrayData], @@ -208,18 +211,8 @@ object RowEncoder { // as java.lang.Object. case _: DecimalType => ObjectType(classOf[java.lang.Object]) // In order to support both Array and Seq in external row, we make this as java.lang.Object. - case a @ ArrayType(et, cn) if !cn => - et match { - case BooleanType => ObjectType(classOf[Array[Boolean]]) - case ByteType => ObjectType(classOf[Array[Byte]]) - case ShortType => ObjectType(classOf[Array[Short]]) - case IntegerType => ObjectType(classOf[Array[Int]]) - case LongType => ObjectType(classOf[Array[Long]]) - case FloatType => ObjectType(classOf[Array[Float]]) - case DoubleType => ObjectType(classOf[Array[Double]]) - case _ => ObjectType(classOf[java.lang.Object]) - } - case _: ArrayType => ObjectType(classOf[java.lang.Object]) + case a @ ArrayType(et, cn) => + ObjectType(classOf[java.lang.Object]) case _ => externalDataTypeFor(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a33e53a3408f..228285de9a77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -994,11 +994,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case a @ ArrayType(et, cn) => - if (!cn && ctx.isPrimitiveType(et)) { - s"$obj instanceof ${ctx.javaType(et)}[]" - } else { - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" - } + s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" case _ => s"$obj instanceof ${ctx.boxedType(dataType)}" } From c5378f95aa8df1eda8d4694bfb14a176f2004477 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 4 Nov 2016 20:09:02 +0900 Subject: [PATCH 09/13] addressed review comments --- .../spark/sql/catalyst/ScalaReflection.scala | 8 +---- .../sql/catalyst/encoders/RowEncoder.scala | 29 ++++--------------- .../expressions/objects/objects.scala | 2 +- .../spark/sql/catalyst/util/ArrayData.scala | 15 +++++++++- 4 files changed, 22 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4ed9f88d0464..7bcaea7ea2f7 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -444,13 +444,7 @@ object ScalaReflection extends ScalaReflection { case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => val cls = input.dataType.asInstanceOf[ObjectType].cls - if (cls.isAssignableFrom(classOf[Array[Boolean]]) || - cls.isAssignableFrom(classOf[Array[Byte]]) || - cls.isAssignableFrom(classOf[Array[Short]]) || - cls.isAssignableFrom(classOf[Array[Int]]) || - cls.isAssignableFrom(classOf[Array[Long]]) || - cls.isAssignableFrom(classOf[Array[Float]]) || - cls.isAssignableFrom(classOf[Array[Double]])) { + if (cls.isArray && cls.getComponentType.isPrimitive) { StaticInvoke( classOf[UnsafeArrayData], ArrayType(dt, false), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index df03075207f1..157c03b7482d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkException import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.objects._ @@ -122,28 +122,12 @@ object RowEncoder { case t @ ArrayType(et, cn) => val cls = inputObject.dataType.asInstanceOf[ObjectType].cls et match { -/* - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType - if !cn && ( - cls.isAssignableFrom(classOf[Array[Boolean]]) || - cls.isAssignableFrom(classOf[Array[Byte]]) || - cls.isAssignableFrom(classOf[Array[Short]]) || - cls.isAssignableFrom(classOf[Array[Int]]) || - cls.isAssignableFrom(classOf[Array[Long]]) || - cls.isAssignableFrom(classOf[Array[Float]]) || - cls.isAssignableFrom(classOf[Array[Double]])) => - print(s"1@ET: $et, $cn, $cls\n") + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => StaticInvoke( - classOf[UnsafeArrayData], - ArrayType(et, false), - "fromPrimitiveArray", + classOf[ArrayData], + ObjectType(classOf[ArrayData]), + "toArrayData", inputObject :: Nil) -*/ - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = t) case _ => MapObjects( element => serializerFor(ValidateExternalType(element, et), et), inputObject, @@ -211,8 +195,7 @@ object RowEncoder { // as java.lang.Object. case _: DecimalType => ObjectType(classOf[java.lang.Object]) // In order to support both Array and Seq in external row, we make this as java.lang.Object. 
- case a @ ArrayType(et, cn) => - ObjectType(classOf[java.lang.Object]) + case _: ArrayType => ObjectType(classOf[java.lang.Object]) case _ => externalDataTypeFor(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 228285de9a77..50e2ac3c36d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -993,7 +993,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: DecimalType => Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") - case a @ ArrayType(et, cn) => + case _: ArrayType => s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" case _ => s"$obj instanceof ${ctx.boxedType(dataType)}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index cad4a08b0d83..140e86d670a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,9 +19,22 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} import org.apache.spark.sql.types.DataType +object ArrayData { + def toArrayData(input: Any): ArrayData = input match { + case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a) + case other => new GenericArrayData(other) + } +} + abstract class ArrayData extends SpecializedGetters with Serializable { def numElements(): Int From d9e5b4f9c934d1a409d9a74cb8f77745f8e46f51 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 5 Nov 2016 18:08:27 +0900 Subject: [PATCH 10/13] addressed review comments --- .../sql/catalyst/CatalystTypeConverters.scala | 14 -------------- .../sql/catalyst/CatalystTypeConvertersSuite.scala | 4 ---- 2 files changed, 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9d32df5c133d..5b9161551a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -158,20 +158,6 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): ArrayData = { scalaValue match { - case a: Array[Boolean] => - UnsafeArrayData.fromPrimitiveArray(a) - case a: Array[Byte] => - UnsafeArrayData.fromPrimitiveArray(a) - case a: Array[Short] => - UnsafeArrayData.fromPrimitiveArray(a) - case a: Array[Int] => - UnsafeArrayData.fromPrimitiveArray(a) - 
case a: Array[Long] => - UnsafeArrayData.fromPrimitiveArray(a) - case a: Array[Float] => - UnsafeArrayData.fromPrimitiveArray(a) - case a: Array[Double] => - UnsafeArrayData.fromPrimitiveArray(a) case a: Array[_] => new GenericArrayData(a.map(elementConverter.toCatalyst)) case s: Seq[_] => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 257d20980356..f3702ec92b42 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -69,16 +69,12 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray) val intArrayType = ArrayType(IntegerType, false) assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray) - assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray) - == intUnsafeArray) val doubleArray = Array(1.1, 111.1, 11111.1) val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray) val doubleArrayType = ArrayType(DoubleType, false) assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray) === doubleArray) - assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) - == doubleUnsafeArray) } test("An array with null handling") { From d507cfce431bdee672a9dd81c5111d768cb28185 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 Nov 2016 05:16:50 +0900 Subject: [PATCH 11/13] addressed comments --- .../sql/catalyst/encoders/RowEncoder.scala | 3 +-- .../sql/catalyst/encoders/RowEncoderSuite.scala | 2 +- .../benchmark/PrimitiveArrayBenchmark.scala | 17 +++++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) mode change 100755 => 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 157c03b7482d..e95e97b9dc6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -120,12 +120,11 @@ object RowEncoder { inputObject :: Nil) case t @ ArrayType(et, cn) => - val cls = inputObject.dataType.asInstanceOf[ObjectType].cls et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => StaticInvoke( classOf[ArrayData], - ObjectType(classOf[ArrayData]), + t, "toArrayData", inputObject :: Nil) case _ => MapObjects( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 9224373d935f..1a5569a77dc7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -191,7 +191,7 @@ class RowEncoderSuite extends SparkFunSuite { assert(encoder.serializer.head.nullable == false) } - test("RowEncoder should support a primitive array") { + test("RowEncoder should support primitive arrays") { val schema = new StructType() 
.add("booleanPrimitiveArray", ArrayType(BooleanType, false)) .add("bytePrimitiveArray", ArrayType(ByteType, false)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala old mode 100755 new mode 100644 index b630871373b4..e7c8f2717fd7 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.Benchmark * Benchmark [[PrimitiveArray]] for DataFrame and Dataset program using primitive array * To run this: * 1. replace ignore(...) with test(...) - * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayDataBenchmark" + * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayBenchmark" * * Benchmarks in this file are skipped in normal builds. */ @@ -41,22 +41,23 @@ class PrimitiveArrayBenchmark extends BenchmarkBase { val sc = sparkSession.sparkContext val primitiveIntArray = Array.fill[Int](count)(65535) val dsInt = sc.parallelize(Seq(primitiveIntArray), 1).toDS - dsInt.count + dsInt.count // force to build dataset val intArray = { i: Int => var n = 0 + var len = 0 while (n < iters) { - dsInt.map(e => e).collect + len += dsInt.map(e => e).queryExecution.toRdd.collect.length n += 1 } } val primitiveDoubleArray = Array.fill[Double](count)(65535.0) val dsDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDS - dsDouble.count + dsDouble.count // force to build dataset val doubleArray = { i: Int => var n = 0 - var sum = 0L + var len = 0 while (n < iters) { - dsDouble.map(e => e).collect + len += dsDouble.map(e => e).queryExecution.toRdd.collect.length n += 1 } } @@ -70,8 +71,8 @@ class PrimitiveArrayBenchmark extends BenchmarkBase { Intel Xeon E3-12xx v2 (Ivy Bridge) Write an array in Dataset: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Int 280 / 302 30.0 33.4 1.0X - Double 503 / 519 16.7 60.0 0.6X + Int 352 / 401 23.8 42.0 1.0X + Double 821 / 885 10.2 97.9 0.4X */ } From c2173950f399e910eb7717e2ec8f9e895e9cae1c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 Nov 2016 23:25:14 +0900 Subject: [PATCH 12/13] addressed review comments --- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8b0554d5f0ae..2fe1337584ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -988,7 +988,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } - test("array") { + test("identity map for primitive array") { val arrayByte = Array(1.toByte, 2.toByte, 3.toByte) val arrayInt = Array(1, 2, 3) val arrayLong = Array(1.toLong, 2.toLong, 3.toLong) From 4c679b59b34ea01474baa5d4ac4ce89133a1720d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 Nov 2016 23:26:57 +0900 Subject: [PATCH 13/13] addressed review comments --- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2fe1337584ac..013d1d1fd220 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -988,7 +988,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } - test("identity map for primitive array") { + test("identity map for primitive arrays") { val arrayByte = Array(1.toByte, 2.toByte, 3.toByte) val arrayInt = Array(1, 2, 3) val arrayLong = Array(1.toLong, 2.toLong, 3.toLong)
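
A minimal usage sketch (not part of any patch above). It illustrates the fast path these patches introduce: primitive Scala arrays are converted to `UnsafeArrayData` via `UnsafeArrayData.fromPrimitiveArray` / the new `ArrayData.toArrayData` helper instead of being wrapped element-by-element in a `GenericArrayData`. The object name, `local[1]` master, and application name are assumptions; the API calls mirror the tests added in `DatasetSuite` and `RowEncoderSuite` and assume a Spark build (2.1-era APIs) that contains these patches.

```scala
// Hypothetical standalone example; not part of the patch series.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.ArrayData

object PrimitiveArrayExample {
  def main(args: Array[String]): Unit = {
    // The ArrayData.toArrayData helper (patch 09) takes the UnsafeArrayData
    // path for primitive arrays and falls back to GenericArrayData otherwise.
    val catalystArray = ArrayData.toArrayData(Array(1, 2, 3))
    assert(catalystArray.isInstanceOf[UnsafeArrayData])
    assert(catalystArray.toIntArray().sameElements(Array(1, 2, 3)))

    // End-to-end: a Dataset over a primitive array round-trips through the
    // serializer, exercising the StaticInvoke of fromPrimitiveArray added in
    // ScalaReflection.serializerFor.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("primitive-array-example")
      .getOrCreate()
    import spark.implicits._

    val ds = spark.sparkContext.parallelize(Seq(Array(1.1, 2.2, 3.3)), 1).toDS()
    val roundTripped = ds.map(e => e).collect().head
    assert(roundTripped.sameElements(Array(1.1, 2.2, 3.3)))

    spark.stop()
  }
}
```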