From 11357737ad58a2a6c1ea2e17026669fc138f556c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 15 Jul 2016 22:06:41 +0800 Subject: [PATCH 1/3] Support partial aggregation for reduceGroups. --- .../spark/sql/KeyValueGroupedDataset.scala | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a6867a67eeade..0ff0c4002ce91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ou import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.Aggregator /** * :: Experimental :: @@ -177,10 +178,33 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) + val encoder = encoderFor[V] + val intEncoder: ExpressionEncoder[Int] = ExpressionEncoder() + val aggregator: TypedColumn[V, V] = new Aggregator[V, (Int, V), V] { + def bufferEncoder: Encoder[(Int, V)] = ExpressionEncoder.tuple(intEncoder, encoder) + def outputEncoder: Encoder[V] = encoder - implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc) - flatMapGroups(func) + def zero: (Int, V) = (0, null.asInstanceOf[V]) + def reduce(reducedValue: (Int, V), value: V): (Int, V) = { + if (reducedValue._1 == 0) { + (1, value) + } else { + (1, f(reducedValue._2, value)) + } + } + def merge(buf1: (Int, V), buf2: (Int, V)): (Int, V) = { + if (buf1._1 == 0) { + buf2 + } else if (buf2._2 == 0) { + buf1 + } else { + (1, f(buf1._2, buf2._2)) + } + } + def finish(result: (Int, V)): V = result._2 + }.toColumn + + agg(aggregator) } /** From 7e8d8c116552642573cc89bd11fc2e82f2a0f82a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 18 Jul 2016 16:45:32 +0800 Subject: [PATCH 2/3] Add ReduceAggregator. --- .../spark/sql/KeyValueGroupedDataset.scala | 27 +----- .../sql/expressions/ReduceAggregator.scala | 87 +++++++++++++++++++ .../expressions/ReduceAggregatorSuite.scala | 62 +++++++++++++ 3 files changed, 151 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 0ff0c4002ce91..7e600286e9551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ou import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.ReduceAggregator /** * :: Experimental :: @@ -179,30 +179,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { val encoder = encoderFor[V] - val intEncoder: ExpressionEncoder[Int] = ExpressionEncoder() - val aggregator: TypedColumn[V, V] = new Aggregator[V, (Int, V), V] { - def bufferEncoder: Encoder[(Int, V)] = ExpressionEncoder.tuple(intEncoder, encoder) - def outputEncoder: Encoder[V] = encoder - - def zero: (Int, V) = (0, null.asInstanceOf[V]) - def reduce(reducedValue: (Int, V), value: V): (Int, V) = { - if (reducedValue._1 == 0) { - (1, value) - } else { - (1, f(reducedValue._2, value)) - } - } - def merge(buf1: (Int, V), buf2: (Int, V)): (Int, V) = { - if (buf1._1 == 0) { - buf2 - } else if (buf2._2 == 0) { - buf1 - } else { - (1, f(buf1._2, buf2._2)) - } - } - def finish(result: (Int, V)): V = result._2 - }.toColumn + val aggregator: TypedColumn[V, V] = new ReduceAggregator(f, encoder).toColumn agg(aggregator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala new file mode 100644 index 0000000000000..0f013d49a2cab --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + +/** + * :: Experimental :: + * A generic class for reduce aggregations, which accepts a reduce function that can be used to take + * all of the elements of a group and reduce them to a single value. + * + * @tparam T The input and output type for the reduce function. + * @param func The reduce aggregation function. + * @param encoder The encoder for the input and output type of the reduce function. + * @since 2.1.0 + */ +@Experimental +private[sql] class ReduceAggregator[T](func: (T, T) => T, encoder: ExpressionEncoder[T]) + extends Aggregator[T, (Boolean, T), T] { + + /** + * A zero value for this aggregation. It is represented as a Tuple2. The first element of the + * tuple is a false boolean value indicating the buffer is not initialized. The second element + * is initialized as a null value. + * @since 2.1.0 + */ + override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) + + override def bufferEncoder: Encoder[(Boolean, T)] = + ExpressionEncoder.tuple(ExpressionEncoder[Boolean](), encoder) + + override def outputEncoder: Encoder[T] = encoder + + /** + * Combine two values to produce a new value. If the buffer `b` is not initialized, it simply + * takes the value of `a` and set the initialization flag to `true`. + * @since 2.1.0 + */ + override def reduce(b: (Boolean, T), a: T): (Boolean, T) = { + if (b._1) { + (true, func(b._2, a)) + } else { + (true, a) + } + } + + /** + * Merge two intermediate values. As it is possibly that the buffer is just the `zero` value + * coming from empty partition, it checks if the buffers are initialized, and only performs + * merging when they are initialized both. + * @since 2.1.0 + */ + override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = { + if (!b1._1) { + b2 + } else if (!b2._1) { + b1 + } else { + (true, func(b1._2, b2._2)) + } + } + + /** + * Transform the output of the reduction. Simply output the value in the buffer. + * @since 2.1.0 + */ + override def finish(reduction: (Boolean, T)): T = { + reduction._2 + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala new file mode 100644 index 0000000000000..d2ce2f82c3053 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.ReduceAggregator + +class ReduceAggregatorSuite extends SparkFunSuite { + test("zero value") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) + assert(aggregator.zero == (false, null)) + } + + test("reduce, merge and finish") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) + + val firstReduce = aggregator.reduce(aggregator.zero, 1) + assert(firstReduce == (true, 1)) + + val secondReduce = aggregator.reduce(firstReduce, 2) + assert(secondReduce == (true, 3)) + + val thirdReduce = aggregator.reduce(secondReduce, 3) + assert(thirdReduce == (true, 6)) + + val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce) + assert(mergeWithZero1 == (true, 1)) + + val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero) + assert(mergeWithZero2 == (true, 3)) + + val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce) + assert(mergeTwoReduced == (true, 4)) + + assert(aggregator.finish(firstReduce)== 1) + assert(aggregator.finish(secondReduce) == 3) + assert(aggregator.finish(thirdReduce) == 6) + assert(aggregator.finish(mergeWithZero1) == 1) + assert(aggregator.finish(mergeWithZero2) == 3) + assert(aggregator.finish(mergeTwoReduced) == 4) + } +} From 6032325ccfa344d1d57a40e7c91d70279c744077 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 20 Jul 2016 16:59:07 +0800 Subject: [PATCH 3/3] For comment. --- .../spark/sql/DatasetAggregatorSuite.scala | 40 +++++++++++- .../expressions/ReduceAggregatorSuite.scala | 62 ------------------- 2 files changed, 39 insertions(+), 63 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index ddc4dcd2395b2..535a6c7d21ed1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.language.postfixOps import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator} import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -314,4 +314,42 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData] assert(ds3.select(NameAgg.toColumn).schema.head.nullable === true) } + + test("ReduceAggregator: zero value") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) + assert(aggregator.zero == (false, null)) + } + + test("ReduceAggregator: reduce, merge and finish") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) + + val firstReduce = aggregator.reduce(aggregator.zero, 1) + assert(firstReduce == (true, 1)) + + val secondReduce = aggregator.reduce(firstReduce, 2) + assert(secondReduce == (true, 3)) + + val thirdReduce = aggregator.reduce(secondReduce, 3) + assert(thirdReduce == (true, 6)) + + val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce) + assert(mergeWithZero1 == (true, 1)) + + val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero) + assert(mergeWithZero2 == (true, 3)) + + val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce) + assert(mergeTwoReduced == (true, 4)) + + assert(aggregator.finish(firstReduce)== 1) + assert(aggregator.finish(secondReduce) == 3) + assert(aggregator.finish(thirdReduce) == 6) + assert(aggregator.finish(mergeWithZero1) == 1) + assert(aggregator.finish(mergeWithZero2) == 3) + assert(aggregator.finish(mergeTwoReduced) == 4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala deleted file mode 100644 index d2ce2f82c3053..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.expressions.ReduceAggregator - -class ReduceAggregatorSuite extends SparkFunSuite { - test("zero value") { - val encoder: ExpressionEncoder[Int] = ExpressionEncoder() - val func = (v1: Int, v2: Int) => v1 + v2 - val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) - assert(aggregator.zero == (false, null)) - } - - test("reduce, merge and finish") { - val encoder: ExpressionEncoder[Int] = ExpressionEncoder() - val func = (v1: Int, v2: Int) => v1 + v2 - val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func, encoder) - - val firstReduce = aggregator.reduce(aggregator.zero, 1) - assert(firstReduce == (true, 1)) - - val secondReduce = aggregator.reduce(firstReduce, 2) - assert(secondReduce == (true, 3)) - - val thirdReduce = aggregator.reduce(secondReduce, 3) - assert(thirdReduce == (true, 6)) - - val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce) - assert(mergeWithZero1 == (true, 1)) - - val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero) - assert(mergeWithZero2 == (true, 3)) - - val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce) - assert(mergeTwoReduced == (true, 4)) - - assert(aggregator.finish(firstReduce)== 1) - assert(aggregator.finish(secondReduce) == 3) - assert(aggregator.finish(thirdReduce) == 6) - assert(aggregator.finish(mergeWithZero1) == 1) - assert(aggregator.finish(mergeWithZero2) == 3) - assert(aggregator.finish(mergeTwoReduced) == 4) - } -}