From 64e18a78a4184bcefaf000e8107f2f3ac5a5d659 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 14 May 2015 17:29:31 +0800 Subject: [PATCH 01/16] Add Window Function support for DataFrame --- .../scala/org/apache/spark/sql/Column.scala | 10 + .../spark/sql/WindowFunctionDefinition.scala | 257 ++++++++++++++++++ .../sql/hive/HiveDataFrameWindowSuite.scala | 165 +++++++++++ 3 files changed, 432 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index dc0aeea7c4ae..f7f5b956b546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -889,6 +889,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) + /** + * Create a new [[WindowFunctionDefinition]] bundled with this column(expression). + * {{{ + * df.select(avg($"value").over...) + * }}} + * + * @group expr_ops + */ + def over: WindowFunctionDefinition = new WindowFunctionDefinition(this) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala new file mode 100644 index 000000000000..bfabb5e2b03e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * A set of methods for window function definition for aggregate expressions. + * For example: + * {{{ + * df.select( + * avg("value") + * .over + * .partitionBy("k1") + * .orderBy("k2", "k3") + * .row + * .following(1) + * .toColumn.as("avg_value"), + * max("value") + * .over + * .partitionBy("k2") + * .orderBy("k3") + * .between + * .preceding(4) + * .following(3) + * .toColumn.as("max_value")) + * }}} + * + * + */ +@Experimental +class WindowFunctionDefinition protected[sql]( + column: Column, + partitionSpec: Seq[Expression] = Nil, + orderSpec: Seq[SortOrder] = Nil, + frame: WindowFrame = UnspecifiedFrame) { + + /** + * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. + * {{{ + * // The following 2 are equivalent + * df.over.partitionBy("k1", "k2", ...) + * df.over.partitionBy($"K1", $"k2", ...) 
+ * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowFunctionDefinition = { + partitionBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. For example: + * {{{ + * df.over.partitionBy($"col1", $"col2") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowFunctionDefinition = { + new WindowFunctionDefinition(column, cols.map(_.expr), orderSpec, frame) + } + + /** + * Returns a new [[WindowFunctionDefinition]] sorted by the specified column within + * the partition. + * {{{ + * // The following 2 are equivalent + * df.over.partitionBy("k1").orderBy("k2", "k3") + * df.over.partitionBy("k1").orderBy($"k2", $"k3") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowFunctionDefinition = { + orderBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Returns a new [[WindowFunctionDefinition]] sorted by the specified column within + * the partition. For example + * {{{ + * df.over.partitionBy("k1").orderBy($"k2", $"k3") + * }}} + * @group window_funcs + */ + def orderBy(cols: Column*): WindowFunctionDefinition = { + val sortOrder: Seq[SortOrder] = cols.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + new WindowFunctionDefinition(column, partitionSpec, sortOrder, frame) + } + + /** + * Returns a new ranged [[WindowFunctionDefinition]]. For example: + * {{{ + * df.over.partitionBy("k1").orderBy($"k2", $"k3").between + * }}} + * @group window_funcs + */ + def between: WindowFunctionDefinition = { + new WindowFunctionDefinition(column, partitionSpec, orderSpec, + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, UnboundedFollowing)) + } + + /** + * Returns a new [[WindowFunctionDefinition]], with fixed number of records + * from/to CURRENT ROW. For example: + * {{{ + * df.over.partitionBy("k1").orderBy($"k2", $"k3").row + * }}} + * @group window_funcs + */ + def rows: WindowFunctionDefinition = { + new WindowFunctionDefinition(column, partitionSpec, orderSpec, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)) + } + + /** + * Returns a new [[WindowFunctionDefinition]], with range of preceding position specified. + * For a Ranged [[WindowFunctionDefinition]], the range is [CURRENT_ROW - n, unspecified] + * For a Fixed Row [[WindowFunctionDefinition]], the range as [CURRENT_ROW - n, CURRENT_ROW]. + * For example: + * {{{ + * // The range is [CURRENT_ROW - 1, CURRENT_ROW] + * df.over.partitionBy("k1").orderBy($"k2", $"k3").row.preceding(1) + * // The range [CURRENT_ROW - 1, previous upper bound] + * df.over.partitionBy("k1").orderBy($"k2", $"k3").between.preceding(1) + * }}} + * If n equals 0, it will be considered as CURRENT_ROW + * @group window_funcs + */ + def preceding(n: Int): WindowFunctionDefinition = { + val newFrame = frame match { + case f @ SpecifiedWindowFrame(RowFrame, _, _) if n == 0 => // TODO should we need this? 
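+        // n == 0 is treated as CURRENT ROW (see the scaladoc above): the row
+        // frame collapses to [CURRENT ROW, CURRENT ROW]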
+ f.copy(frameStart = CurrentRow, frameEnd = CurrentRow) + case f @ SpecifiedWindowFrame(RowFrame, _, _) => + f.copy(frameStart = ValuePreceding(n), frameEnd = CurrentRow) + case f @ SpecifiedWindowFrame(RangeFrame, _, _) if n == 0 => f.copy(frameStart = CurrentRow) + case f @ SpecifiedWindowFrame(RangeFrame, _, _) => f.copy(frameStart = ValuePreceding(n)) + case f => throw new UnsupportedOperationException(s"preceding on $f") + } + new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame) + } + + /** + * Returns a new [[WindowFunctionDefinition]], with range of following position specified. + * For a Ranged [[WindowFunctionDefinition]], the range is [unspecified, CURRENT_ROW + n] + * For a Fixed Row [[WindowFunctionDefinition]], the range as [CURRENT_ROW, CURRENT_ROW + n]. + * For example: + * {{{ + * // The range is [CURRENT_ROW, CURRENT_ROW + 1] + * df.over.partitionBy("k1").orderBy($"k2", $"k3").row.following(1) + * // The range [previous lower bound, CURRENT_ROW + 1] + * df.over.partitionBy("k1").orderBy($"k2", $"k3").between.following(1) + * }}} + * If n equals 0, it will be considered as CURRENT_ROW + * @group window_funcs + */ + def following(n: Int): WindowFunctionDefinition = { + val newFrame = frame match { + case f @ SpecifiedWindowFrame(RowFrame, _, _) if n == 0 => // TODO should we need this? + f.copy(frameStart = CurrentRow, frameEnd = CurrentRow) + case f @ SpecifiedWindowFrame(RowFrame, _, _) => + f.copy(frameStart = CurrentRow, frameEnd = ValueFollowing(n)) + case f @ SpecifiedWindowFrame(RangeFrame, _, _) if n == 0 => f.copy(frameEnd = CurrentRow) + case f @ SpecifiedWindowFrame(RangeFrame, _, _) => f.copy(frameEnd = ValuePreceding(n)) + case f => throw new UnsupportedOperationException(s"following on $f") + } + new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame) + } + + /** + * Convert the window definition into a new Column. + * Currently, only aggregate expressions are supported for window function. 
For Example: + * {{{ + * df.select( + * avg("value") + * .over + * .partitionBy("k1") + * .orderBy($"k2", $"k3") + * .row + * .following(1) + * .toColumn.as("avg_value"), + * max("value") + * .over + * .partitionBy("k2") + * .orderBy("k3") + * .between + * .preceding(4) + * .following(3) + * .toColumn.as("max_value")) + * }}} + * @group window_funcs + */ + def toColumn: Column = { + val windowExpr = column.expr match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction("first_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction("last_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case wf: WindowFunction => WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case aggr: AggregateExpression => + throw new UnsupportedOperationException( + """Only support Aggregate Functions: + | avg, sum, count, first, last, min, max for now""".stripMargin) + case x => + throw new UnsupportedOperationException(s"We don't support $x in window operation.") + } + new Column(windowExpr) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala new file mode 100644 index 000000000000..62c9ed95cb5b --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{AnalysisException, Row, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +class HiveDataFrameWindowSuite extends QueryTest { + + test("aggregation in a Row window") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + avg("key").over + .partitionBy($"key") + .orderBy($"value") + .rows + .preceding(1) + .toColumn), + Row(1.0) :: Row(2.0) :: Nil) + } + + test("aggregation in a Range window") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + avg("key").over + .partitionBy($"value") + .orderBy($"key") + .between + .preceding(1) + .following(1) + .toColumn), + Row(1.0) :: Row(2.0) :: Nil) + } + + test("multiple aggregate function in window") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + avg("key").over + .partitionBy($"value") + .orderBy($"key") + .rows + .preceding(1).toColumn, + sum("key").over + .partitionBy($"value") + .orderBy($"key") + .between + .preceding(1) + .following(1) + .toColumn), + Row(1, 1.0) :: Row(2, 2.0) :: Nil) + } + + test("Window function in Unspecified Window") { + val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + + checkAnswer( + df.select( + $"key", + first("value").over + .partitionBy($"key") + .toColumn), + Row(1, "1") :: Row(2, "2") :: Row(2, "2") :: Nil) + } + + test("Window function in Unspecified Window #2") { + val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + + checkAnswer( + df.select( + $"key", + first("value").over + .partitionBy($"key") + .orderBy($"value") + .toColumn), + Row(1, "1") :: Row(2, "2") :: Row(2, "2") :: Nil) + } + + test("Aggregate function in Range Window") { + val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + + checkAnswer( + df.select( + $"key", + first("value").over + .partitionBy($"value") + .orderBy($"key") + .between + .preceding(1) + .following(1) + .toColumn), + Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) + } + + test("Aggregate function in Row preceding Window") { + val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + checkAnswer( + df.select( + $"key", + first("value").over + .partitionBy($"value") + .orderBy($"key") + .rows + .preceding(1) + .toColumn), + Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) + } + + test("Aggregate function in Row following Window") { + val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + checkAnswer( + df.select( + $"key", + last("value").over + .partitionBy($"value") + .orderBy($"key") + .rows + .following(1) + .toColumn), + Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) + } + + test("Multiple aggregate functions") { + val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + checkAnswer( + df.select( + $"key", + last("value").over + .partitionBy($"value") + .orderBy($"key") + .rows + .following(1) + .toColumn + .equalTo("2") + .as("last_v"), + avg("key") + .over + .partitionBy("value") + .orderBy("key") + .between + .preceding(2) + .following(1) + .toColumn.as("avg_key") + ), + Row(1, false, 1.0) :: Row(2, false, 2.0) :: Row(2, true, 2.0) :: Nil) + } +} From 964c013494a31ce5f06b111541dbde71ab6239c2 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 15 May 2015 22:34:19 +0800 Subject: [PATCH 02/16] add more unit tests and window functions --- .../scala/org/apache/spark/sql/Column.scala | 16 ++- .../spark/sql/WindowFunctionDefinition.scala | 9 +- 
.../org/apache/spark/sql/functions.scala | 128 ++++++++++++++++++ .../sql/hive/HiveDataFrameWindowSuite.scala | 71 +++++++++- 4 files changed, 216 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f7f5b956b546..0132c93ba2c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -890,7 +890,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) /** - * Create a new [[WindowFunctionDefinition]] bundled with this column(expression). + * Create a new [[WindowFunctionDefinition]] bundled with this column. * {{{ * df.select(avg($"value").over...) * }}} @@ -899,6 +899,20 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def over: WindowFunctionDefinition = new WindowFunctionDefinition(this) + /** + * Reuse an existed [[WindowFunctionDefinition]] and bundled with this column. + * {{{ + * val w = over.partitionBy("name").orderBy("id") + * df.select( + * sum("price").over(w).between.preceding(2), + * avg("price").over(w).between.preceding(4) + * ) + * }}} + * + * @group expr_ops + */ + def over(w: WindowFunctionDefinition) = w.newColumn(this) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala index bfabb5e2b03e..7bdd0daebf59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala @@ -49,11 +49,15 @@ import org.apache.spark.sql.catalyst.expressions._ */ @Experimental class WindowFunctionDefinition protected[sql]( - column: Column, + column: Column = null, partitionSpec: Seq[Expression] = Nil, orderSpec: Seq[SortOrder] = Nil, frame: WindowFrame = UnspecifiedFrame) { + private[sql] def newColumn(c: Column): WindowFunctionDefinition = { + new WindowFunctionDefinition(c, partitionSpec, orderSpec, frame) + } + /** * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. * {{{ @@ -218,6 +222,9 @@ class WindowFunctionDefinition protected[sql]( * @group window_funcs */ def toColumn: Column = { + if (column == null) { + throw new AnalysisException("Window didn't bind with expression") + } val windowExpr = column.expr match { case Average(child) => WindowExpression( UnresolvedWindowFunction("avg", child :: Nil), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6640631cf071..5e6947ae4667 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -320,6 +320,132 @@ object functions { */ def max(columnName: String): Column = max(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Window functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Window function: returns the lag value of current row of the expression, + * null when the current row extends before the beginning of the window. 
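+   * For example (a usage sketch mirroring the test suite below, where `df` is a
+   * DataFrame and the window is predefined with the `over` builder):
+   * {{{
+   *   val w = over.partitionBy("key").orderBy("value")
+   *   df.select(lag("value").over(w).toColumn)
+   * }}}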
+ * + * @group window_funcs + */ + def lag(columnName: String): Column = { + lag(columnName, 1) + } + + /** + * Window function: returns the lag value of current row of the column, + * null when the current row extends before the beginning of the window. + * + * @group window_funcs + */ + def lag(e: Column): Column = { + lag(e, 1) + } + + /** + * Window function: returns the lag values of current row of the expression, + * null when the current row extends before the beginning of the window. + * + * @group window_funcs + */ + def lag(e: Column, count: Int): Column = { + lag(e, count, null) + } + + /** + * Window function: returns the lag values of current row of the column, + * null when the current row extends before the beginning of the window. + * + * @group window_funcs + */ + def lag(columnName: String, count: Int): Column = { + lag(columnName, count, null) + } + + /** + * Window function: returns the lag values of current row of the column, + * given default value when the current row extends before the beginning + * of the window. + * + * @group window_funcs + */ + def lag(columnName: String, count: Int, defaultValue: Any): Column = { + lag(Column(columnName), count, defaultValue) + } + + /** + * Window function: returns the lag values of current row of the expression, + * given default value when the current row extends before the beginning + * of the window. + * + * @group window_funcs + */ + def lag(e: Column, count: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the lead value of current row of the column, + * null when the current row extends before the end of the window. + * + * @group window_funcs + */ + def lead(columnName: String): Column = { + lead(columnName, 1) + } + + /** + * Window function: returns the lead value of current row of the expression, + * null when the current row extends before the end of the window. + * + * @group window_funcs + */ + def lead(e: Column): Column = { + lead(e, 1) + } + + /** + * Window function: returns the lead values of current row of the column, + * null when the current row extends before the end of the window. + * + * @group window_funcs + */ + def lead(columnName: String, count: Int): Column = { + lead(columnName, count, null) + } + + /** + * Window function: returns the lead values of current row of the expression, + * null when the current row extends before the end of the window. + * + * @group window_funcs + */ + def lead(e: Column, count: Int): Column = { + lead(e, count, null) + } + + /** + * Window function: returns the lead values of current row of the column, + * given default value when the current row extends before the end of the window. + * + * @group window_funcs + */ + def lead(columnName: String, count: Int, defaultValue: Any): Column = { + lead(Column(columnName), count, defaultValue) + } + + /** + * Window function: returns the lead values of current row of the expression, + * given default value when the current row extends before the end of the window. 
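+   * For example (a sketch, reusing the window definition `w` from the `lag` example
+   * above; "n/a" is the default returned past the end of the window):
+   * {{{
+   *   df.select(lead($"value", 2, "n/a").over(w).toColumn)
+   * }}}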
+ * + * @group window_funcs + */ + def lead(e: Column, count: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1393,4 +1519,6 @@ object functions { UnresolvedFunction(udfName, cols.map(_.expr)) } + def over: WindowFunctionDefinition = new WindowFunctionDefinition() + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 62c9ed95cb5b..c116c5fa39ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -23,8 +23,67 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ class HiveDataFrameWindowSuite extends QueryTest { + test("reuse window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = over.partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key").over(w).toColumn, + lead("value").over(w).toColumn), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("lead in window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + lead("value").over + .partitionBy($"key") + .orderBy($"value") + .toColumn), + Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil) + } + + test("lag in window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + lead("value").over + .partitionBy($"key") + .orderBy($"value") + .toColumn), + Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil) + } + + test("lead in window with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + lead("value", 2, "n/a").over + .partitionBy("key") + .orderBy("value") + .toColumn), + Row("1") :: Row("1") :: Row("2") :: Row("n/a") + :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil) + } + + test("lag in window with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + lead("value", 2, "n/a").over + .partitionBy($"key") + .orderBy($"value") + .toColumn), + Row("1") :: Row("1") :: Row("2") :: Row("n/a") + :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil) + } + test("aggregation in a Row window") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( avg("key").over @@ -33,11 +92,11 @@ class HiveDataFrameWindowSuite extends QueryTest { .rows .preceding(1) .toColumn), - Row(1.0) :: Row(2.0) :: Nil) + Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil) } test("aggregation in a Range window") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( avg("key").over @@ -47,11 +106,11 @@ class HiveDataFrameWindowSuite extends QueryTest { .preceding(1) .following(1) .toColumn), - Row(1.0) :: Row(2.0) :: Nil) + Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil) } test("multiple 
aggregate function in window") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( avg("key").over @@ -66,7 +125,7 @@ class HiveDataFrameWindowSuite extends QueryTest { .preceding(1) .following(1) .toColumn), - Row(1, 1.0) :: Row(2, 2.0) :: Nil) + Row(1.0, 1) :: Row(1.0, 2) :: Row(2.0, 2) :: Row(2.0, 4) :: Nil) } test("Window function in Unspecified Window") { From 53f89f28ec223001069fc728ed7ef98dcce10b03 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 15 May 2015 22:51:01 +0800 Subject: [PATCH 03/16] remove the over from the functions.scala --- .../org/apache/spark/sql/functions.scala | 57 ++++++++++++++++++- .../sql/hive/HiveDataFrameWindowSuite.scala | 15 ++++- 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5e6947ae4667..c21cbd95ed67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -446,6 +446,61 @@ object functions { UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil) } + /** + * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. + * For example: + * {{{ + * // The following 2 are equivalent + * partitionBy("k1", "k2").orderBy("k3") + * partitionBy($"K1", $"k2").orderBy($"k3") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowFunctionDefinition = { + new WindowFunctionDefinition().partitionBy(colName, colNames: _*) + } + + /** + * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. + * For example: + * {{{ + * partitionBy($"col1", $"col2").orderBy("value") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowFunctionDefinition = { + new WindowFunctionDefinition().partitionBy(cols: _*) + } + + /** + * Create a new [[WindowFunctionDefinition]] sorted by the specified columns. + * For example: + * {{{ + * // The following 2 are equivalent + * orderBy("k2", "k3").partitionBy("k1") + * orderBy($"k2", $"k3").partitionBy("k1") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowFunctionDefinition = { + new WindowFunctionDefinition().orderBy(colName, colNames: _*) + } + + /** + * Returns a new [[WindowFunctionDefinition]] sorted by the specified columns. 
+ * For example + * {{{ + * val w = orderBy($"k2", $"k3").partitionBy("k1") + * }}} + * @group window_funcs + */ + def orderBy(cols: Column*): WindowFunctionDefinition = { + new WindowFunctionDefinition().orderBy(cols: _*) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1519,6 +1574,4 @@ object functions { UnresolvedFunction(udfName, cols.map(_.expr)) } - def over: WindowFunctionDefinition = new WindowFunctionDefinition() - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index c116c5fa39ec..43fadc525f53 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -23,9 +23,20 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ class HiveDataFrameWindowSuite extends QueryTest { - test("reuse window") { + test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = over.partitionBy("key").orderBy("value") + val w = partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key").over(w).toColumn, + lead("value").over(w).toColumn), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = orderBy("value").partitionBy("key") checkAnswer( df.select( From 1d918655259d51c590acef8fc97ce4f17a02aadc Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 15 May 2015 22:56:21 +0800 Subject: [PATCH 04/16] style issue --- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 0132c93ba2c3..803da52069f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -911,7 +911,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def over(w: WindowFunctionDefinition) = w.newColumn(this) + def over(w: WindowFunctionDefinition): WindowFunctionDefinition = w.newColumn(this) } From 28222ed67d92d1b3c4bf5003bd623fb48805c62c Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 17 May 2015 22:26:28 +0800 Subject: [PATCH 05/16] fix bug of range/row Frame --- .../scala/org/apache/spark/sql/Column.scala | 21 +- .../spark/sql/WindowFunctionDefinition.scala | 160 ++++++++++---- .../sql/hive/HiveDataFrameWindowSuite.scala | 204 +++++++++++------- 3 files changed, 259 insertions(+), 126 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 803da52069f4..f096b17106dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -890,28 +890,19 @@ class Column(protected[sql] val expr: Expression) extends Logging { def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) /** - * Create a new [[WindowFunctionDefinition]] bundled with this column. 
+ * Define a [[WindowFunctionDefinition]] column. * {{{ - * df.select(avg($"value").over...) - * }}} - * - * @group expr_ops - */ - def over: WindowFunctionDefinition = new WindowFunctionDefinition(this) - - /** - * Reuse an existed [[WindowFunctionDefinition]] and bundled with this column. - * {{{ - * val w = over.partitionBy("name").orderBy("id") + * val w = partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w).between.preceding(2), - * avg("price").over(w).between.preceding(4) + * sum("price").over(w).range.preceding(2), + * avg("price").over(w).range.preceding(4), + * avg("price").over(partitionBy("name").orderBy("id).range.preceding(1)) * ) * }}} * * @group expr_ops */ - def over(w: WindowFunctionDefinition): WindowFunctionDefinition = w.newColumn(this) + def over(w: WindowFunctionDefinition): Column = w.newColumn(this).toColumn } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala index 7bdd0daebf59..fce76755a8ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala @@ -52,10 +52,11 @@ class WindowFunctionDefinition protected[sql]( column: Column = null, partitionSpec: Seq[Expression] = Nil, orderSpec: Seq[SortOrder] = Nil, - frame: WindowFrame = UnspecifiedFrame) { + frame: WindowFrame = UnspecifiedFrame, + bindLower: Boolean = true) { private[sql] def newColumn(c: Column): WindowFunctionDefinition = { - new WindowFunctionDefinition(c, partitionSpec, orderSpec, frame) + new WindowFunctionDefinition(c, partitionSpec, orderSpec, frame, bindLower) } /** @@ -120,20 +121,48 @@ class WindowFunctionDefinition protected[sql]( } /** - * Returns a new ranged [[WindowFunctionDefinition]]. For example: + * Returns the current [[WindowFunctionDefinition]]. This is a dummy function, + * which makes the usage more like the SQL. + * For example: * {{{ - * df.over.partitionBy("k1").orderBy($"k2", $"k3").between + * df.over.partitionBy("k1").orderBy($"k2", $"k3").range.between * }}} * @group window_funcs */ def between: WindowFunctionDefinition = { + assert(this.frame.isInstanceOf[SpecifiedWindowFrame], "Should be a WindowFrame.") + new WindowFunctionDefinition(column, partitionSpec, orderSpec, frame, true) + } + + /** + * Returns a new [[WindowFunctionDefinition]] indicate that we need to specify the + * upper bound. + * For example: + * {{{ + * df.over.partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(3).and + * }}} + * @group window_funcs + */ + def and: WindowFunctionDefinition = { + new WindowFunctionDefinition(column, partitionSpec, orderSpec, frame, false) + } + + /** + * Returns a new Ranged [[WindowFunctionDefinition]]. + * For example: + * {{{ + * df.over.partitionBy("k1").orderBy($"k2", $"k3").range.between + * }}} + * @group window_funcs + */ + def range: WindowFunctionDefinition = { new WindowFunctionDefinition(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, UnboundedFollowing)) } /** - * Returns a new [[WindowFunctionDefinition]], with fixed number of records - * from/to CURRENT ROW. For example: + * Returns a new [[WindowFunctionDefinition]], with fixed number of records. 
+ * For example: * {{{ * df.over.partitionBy("k1").orderBy($"k2", $"k3").row * }}} @@ -145,57 +174,114 @@ class WindowFunctionDefinition protected[sql]( } /** - * Returns a new [[WindowFunctionDefinition]], with range of preceding position specified. - * For a Ranged [[WindowFunctionDefinition]], the range is [CURRENT_ROW - n, unspecified] - * For a Fixed Row [[WindowFunctionDefinition]], the range as [CURRENT_ROW - n, CURRENT_ROW]. + * Returns a new [[WindowFunctionDefinition]], with position specified preceding of CURRENT_ROW. + * It can be either Lower or Upper Bound position, depends on whether the `and` method called. * For example: * {{{ - * // The range is [CURRENT_ROW - 1, CURRENT_ROW] - * df.over.partitionBy("k1").orderBy($"k2", $"k3").row.preceding(1) - * // The range [CURRENT_ROW - 1, previous upper bound] - * df.over.partitionBy("k1").orderBy($"k2", $"k3").between.preceding(1) + * // [CURRENT_ROW - 1, ~) + * df.over(partitionBy("k1").orderBy("k2").row.preceding(1)) + * // [CURRENT_ROW - 3, CURRENT_ROW - 1] + * df.over(partitionBy("k1").orderBy("k2").row.between.preceding(3).and.preceding(1)) + * // (~, CURRENT_ROW - 1] + * df.over(partitionBy("k1").orderBy("k2").row.between.unboundedPreceding.and.preceding(1)) * }}} - * If n equals 0, it will be considered as CURRENT_ROW * @group window_funcs */ def preceding(n: Int): WindowFunctionDefinition = { + assert(n > 0) val newFrame = frame match { - case f @ SpecifiedWindowFrame(RowFrame, _, _) if n == 0 => // TODO should we need this? - f.copy(frameStart = CurrentRow, frameEnd = CurrentRow) - case f @ SpecifiedWindowFrame(RowFrame, _, _) => - f.copy(frameStart = ValuePreceding(n), frameEnd = CurrentRow) - case f @ SpecifiedWindowFrame(RangeFrame, _, _) if n == 0 => f.copy(frameStart = CurrentRow) - case f @ SpecifiedWindowFrame(RangeFrame, _, _) => f.copy(frameStart = ValuePreceding(n)) + case f: SpecifiedWindowFrame if bindLower => + f.copy(frameStart = ValuePreceding(n)) + case f: SpecifiedWindowFrame => + f.copy(frameEnd = ValuePreceding(n)) case f => throw new UnsupportedOperationException(s"preceding on $f") } - new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame) + new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) + } + + /** + * Returns a new [[WindowFunctionDefinition]], with lower position as unbounded. + * For example: + * {{{ + * // (~, CURRENT_ROW] + * df.over(partitionBy("k1").orderBy("k2").row.between.unboundedPreceding.and.currentRow) + * }}} + * @group window_funcs + */ + def unboundedPreceding(): WindowFunctionDefinition = { + val newFrame = frame match { + case f : SpecifiedWindowFrame => + f.copy(frameStart = UnboundedPreceding) + case f => throw new UnsupportedOperationException(s"unboundedPreceding on $f") + } + new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) + } + + /** + * Returns a new [[WindowFunctionDefinition]], with upper position as unbounded. 
+ * For example: + * {{{ + * // [CURRENT_ROW, ~) + * df.over(partitionBy("k1").orderBy("k2").row.between.currentRow.and.unboundedFollowing) + * }}} + * @group window_funcs + */ + def unboundedFollowing(): WindowFunctionDefinition = { + val newFrame = frame match { + case f : SpecifiedWindowFrame => + f.copy(frameEnd = UnboundedFollowing) + case f => throw new UnsupportedOperationException(s"unboundedFollowing on $f") + } + new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) + } + + /** + * Returns a new [[WindowFunctionDefinition]], with position as CURRENT_ROW. + * It can be either Lower or Upper Bound position, depends on whether the `and` method called. + * For example: + * {{{ + * // [CURRENT_ROW, ~) + * df.over(partitionBy("k1").orderBy("k2").row.between.currentRow.and.unboundedFollowing) + * // [CURRENT_ROW - 3, CURRENT_ROW] + * df.over(partitionBy("k1").orderBy("k2").row.between.preceding(3).and.currentRow) + * }}} + * @group window_funcs + */ + def currentRow(): WindowFunctionDefinition = { + val newFrame = frame match { + case f : SpecifiedWindowFrame if bindLower => + f.copy(frameStart = CurrentRow) + case f : SpecifiedWindowFrame => + f.copy(frameEnd = CurrentRow) + case f => throw new UnsupportedOperationException(s"currentRow on $f") + } + new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) } /** - * Returns a new [[WindowFunctionDefinition]], with range of following position specified. - * For a Ranged [[WindowFunctionDefinition]], the range is [unspecified, CURRENT_ROW + n] - * For a Fixed Row [[WindowFunctionDefinition]], the range as [CURRENT_ROW, CURRENT_ROW + n]. + * Returns a new [[WindowFunctionDefinition]], with position specified following of CURRENT_ROW. + * It can be either Lower or Upper Bound position, depends on whether the `and` method called. * For example: * {{{ - * // The range is [CURRENT_ROW, CURRENT_ROW + 1] - * df.over.partitionBy("k1").orderBy($"k2", $"k3").row.following(1) - * // The range [previous lower bound, CURRENT_ROW + 1] - * df.over.partitionBy("k1").orderBy($"k2", $"k3").between.following(1) + * // [CURRENT_ROW + 1, ~) + * df.over(partitionBy("k1").orderBy("k2").row.following(1)) + * // [CURRENT_ROW + 1, CURRENT_ROW + 3] + * df.over(partitionBy("k1").orderBy("k2").row.between.following(1).and.following(3)) + * // [CURRENT_ROW + 1, ~) + * df.over(partitionBy("k1").orderBy("k2").row.between.following(1).and.unboundedFollowing) * }}} - * If n equals 0, it will be considered as CURRENT_ROW * @group window_funcs */ def following(n: Int): WindowFunctionDefinition = { + assert(n > 0) val newFrame = frame match { - case f @ SpecifiedWindowFrame(RowFrame, _, _) if n == 0 => // TODO should we need this? 
- f.copy(frameStart = CurrentRow, frameEnd = CurrentRow) - case f @ SpecifiedWindowFrame(RowFrame, _, _) => - f.copy(frameStart = CurrentRow, frameEnd = ValueFollowing(n)) - case f @ SpecifiedWindowFrame(RangeFrame, _, _) if n == 0 => f.copy(frameEnd = CurrentRow) - case f @ SpecifiedWindowFrame(RangeFrame, _, _) => f.copy(frameEnd = ValuePreceding(n)) + case f: SpecifiedWindowFrame if bindLower => + f.copy(frameStart = ValueFollowing(n)) + case f: SpecifiedWindowFrame => + f.copy(frameEnd = ValueFollowing(n)) case f => throw new UnsupportedOperationException(s"following on $f") } - new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame) + new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) } /** @@ -221,7 +307,7 @@ class WindowFunctionDefinition protected[sql]( * }}} * @group window_funcs */ - def toColumn: Column = { + private[sql] def toColumn: Column = { if (column == null) { throw new AnalysisException("Window didn't bind with expression") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 43fadc525f53..2b49f383b2ab 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -29,8 +29,8 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( - lead("key").over(w).toColumn, - lead("value").over(w).toColumn), + lead("key").over(w), + lead("value").over(w)), Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } @@ -40,8 +40,8 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( - lead("key").over(w).toColumn, - lead("value").over(w).toColumn), + lead("key").over(w), + lead("value").over(w)), Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } @@ -49,10 +49,9 @@ class HiveDataFrameWindowSuite extends QueryTest { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - lead("value").over - .partitionBy($"key") - .orderBy($"value") - .toColumn), + lead("value").over( + partitionBy($"key") + .orderBy($"value"))), Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil) } @@ -60,10 +59,9 @@ class HiveDataFrameWindowSuite extends QueryTest { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - lead("value").over - .partitionBy($"key") - .orderBy($"value") - .toColumn), + lead("value").over( + partitionBy($"key") + .orderBy($"value"))), Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil) } @@ -72,10 +70,9 @@ class HiveDataFrameWindowSuite extends QueryTest { (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - lead("value", 2, "n/a").over - .partitionBy("key") - .orderBy("value") - .toColumn), + lead("value", 2, "n/a").over( + partitionBy("key") + .orderBy("value"))), Row("1") :: Row("1") :: Row("2") :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil) } @@ -85,24 +82,25 @@ class HiveDataFrameWindowSuite extends QueryTest { (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - lead("value", 2, "n/a").over - .partitionBy($"key") - .orderBy($"value") - .toColumn), + lead("value", 2, "n/a").over( + partitionBy($"key") + .orderBy($"value"))), Row("1") :: Row("1") :: Row("2") :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil) } - test("aggregation in a 
Row window") { + test("aggregation in a row window") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - avg("key").over - .partitionBy($"key") - .orderBy($"value") - .rows - .preceding(1) - .toColumn), + avg("key").over( + partitionBy($"value") + .orderBy($"key") + .rows + .between + .preceding(1) + .and + .following(1))), Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil) } @@ -110,13 +108,14 @@ class HiveDataFrameWindowSuite extends QueryTest { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - avg("key").over - .partitionBy($"value") + avg("key").over( + partitionBy($"value") .orderBy($"key") + .range .between .preceding(1) - .following(1) - .toColumn), + .and + .following(1))), Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil) } @@ -124,19 +123,20 @@ class HiveDataFrameWindowSuite extends QueryTest { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - avg("key").over - .partitionBy($"value") + avg("key").over( + partitionBy($"value") .orderBy($"key") .rows - .preceding(1).toColumn, - sum("key").over - .partitionBy($"value") + .preceding(1)), + sum("key").over( + partitionBy($"value") .orderBy($"key") + .range .between .preceding(1) - .following(1) - .toColumn), - Row(1.0, 1) :: Row(1.0, 2) :: Row(2.0, 2) :: Row(2.0, 4) :: Nil) + .and + .following(1))), + Row(1.0, 2) :: Row(1.0, 2) :: Row(2.0, 4) :: Row(2.0, 4) :: Nil) } test("Window function in Unspecified Window") { @@ -145,9 +145,8 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - first("value").over - .partitionBy($"key") - .toColumn), + first("value").over( + partitionBy($"key"))), Row(1, "1") :: Row(2, "2") :: Row(2, "2") :: Nil) } @@ -157,10 +156,9 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - first("value").over - .partitionBy($"key") - .orderBy($"value") - .toColumn), + first("value").over( + partitionBy($"key") + .orderBy($"value"))), Row(1, "1") :: Row(2, "2") :: Row(2, "2") :: Nil) } @@ -170,13 +168,14 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - first("value").over - .partitionBy($"value") + first("value").over( + partitionBy($"value") .orderBy($"key") + .range .between .preceding(1) - .following(1) - .toColumn), + .and + .following(1))), Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) } @@ -185,12 +184,11 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - first("value").over - .partitionBy($"value") + first("value").over( + partitionBy($"value") .orderBy($"key") .rows - .preceding(1) - .toColumn), + .preceding(1))), Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) } @@ -199,37 +197,95 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - last("value").over - .partitionBy($"value") + last("value").over( + partitionBy($"value") .orderBy($"key") .rows - .following(1) - .toColumn), + .following(1))), Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) } - test("Multiple aggregate functions") { - val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + test("Multiple aggregate functions in row window") { + val df = Seq((1, "1"), (1, "2"), (3, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.select( + avg("key").over( + partitionBy($"key") + .orderBy($"value") + .rows + .preceding(1)), + avg("key").over( + partitionBy($"key") + .orderBy($"value") + 
.rows + .between + .currentRow + .and + .currentRow), + avg("key").over( + partitionBy($"key") + .orderBy($"value") + .rows + .between + .preceding(2) + .and + .preceding(1))), + Row(1.0, 1.0, 1.0) :: + Row(1.0, 1.0, 1.0) :: + Row(1.0, 1.0, 1.0) :: + Row(2.0, 2.0, 2.0) :: + Row(2.0, 2.0, 2.0) :: + Row(3.0, 3.0, 3.0) :: Nil) + } + + test("Multiple aggregate functions in range window") { + val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( $"key", - last("value").over - .partitionBy($"value") - .orderBy($"key") - .rows - .following(1) - .toColumn + last("value").over( + partitionBy($"value") + .orderBy($"key") + .range + .following(1)) .equalTo("2") .as("last_v"), avg("key") - .over - .partitionBy("value") - .orderBy("key") - .between - .preceding(2) - .following(1) - .toColumn.as("avg_key") + .over( + partitionBy("value") + .orderBy("key") + .range + .between + .preceding(2) + .and + .following(1)) + .as("avg_key1"), + avg("key") + .over( + partitionBy("value") + .orderBy("key") + .range + .between + .currentRow + .and + .following(1)) + .as("avg_key2"), + avg("key") + .over( + partitionBy("value") + .orderBy("key") + .range + .between + .preceding(1) + .and + .currentRow) + .as("avg_key3") ), - Row(1, false, 1.0) :: Row(2, false, 2.0) :: Row(2, true, 2.0) :: Nil) + Row(1, false, 1.0, 1.0, 1.0) :: + Row(1, false, 1.0, 1.0, 1.0) :: + Row(2, true, 2.0, 2.0, 2.0) :: + Row(2, true, 2.0, 2.0, 2.0) :: + Row(2, true, 2.0, 2.0, 2.0) :: + Row(2, true, 2.0, 2.0, 2.0) :: Nil) } } From 24a08eca538208a13a0eddfa53b61718557e3221 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 17 May 2015 22:37:46 +0800 Subject: [PATCH 06/16] scaladoc --- .../spark/sql/WindowFunctionDefinition.scala | 65 ++++++++----------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala index fce76755a8ae..ed66ac00afab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala @@ -27,22 +27,30 @@ import org.apache.spark.sql.catalyst.expressions._ * A set of methods for window function definition for aggregate expressions. * For example: * {{{ + * // predefine a window + * val w = partitionBy("name").orderBy("id") + * * df.select( + * first("value") + * over(w).as("first_value"), + * last("value") + * over(w).as("last_value"), * avg("value") - * .over - * .partitionBy("k1") + * over( + * partitionBy("k1") * .orderBy("k2", "k3") * .row - * .following(1) - * .toColumn.as("avg_value"), + * .following(1)).as("avg_value"), * max("value") - * .over - * .partitionBy("k2") + * .over( + * partitionBy("k2") * .orderBy("k3") + * .range * .between * .preceding(4) - * .following(3) - * .toColumn.as("max_value")) + * .and + * .following(3)).as("max_value")) + * * }}} * * @@ -63,8 +71,8 @@ class WindowFunctionDefinition protected[sql]( * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. * {{{ * // The following 2 are equivalent - * df.over.partitionBy("k1", "k2", ...) - * df.over.partitionBy($"K1", $"k2", ...) + * df.over(partitionBy("k1", "k2", ...)) + * df.over(partitionBy($"K1", $"k2", ...)) * }}} * @group window_funcs */ @@ -76,7 +84,7 @@ class WindowFunctionDefinition protected[sql]( /** * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. 
For example: * {{{ - * df.over.partitionBy($"col1", $"col2") + * df.over(partitionBy($"col1", $"col2")) * }}} * @group window_funcs */ @@ -90,8 +98,8 @@ class WindowFunctionDefinition protected[sql]( * the partition. * {{{ * // The following 2 are equivalent - * df.over.partitionBy("k1").orderBy("k2", "k3") - * df.over.partitionBy("k1").orderBy($"k2", $"k3") + * df.over(partitionBy("k1").orderBy("k2", "k3")) + * df.over(partitionBy("k1").orderBy($"k2", $"k3")) * }}} * @group window_funcs */ @@ -104,7 +112,7 @@ class WindowFunctionDefinition protected[sql]( * Returns a new [[WindowFunctionDefinition]] sorted by the specified column within * the partition. For example * {{{ - * df.over.partitionBy("k1").orderBy($"k2", $"k3") + * df.over(partitionBy("k1").orderBy($"k2", $"k3")) * }}} * @group window_funcs */ @@ -125,7 +133,7 @@ class WindowFunctionDefinition protected[sql]( * which makes the usage more like the SQL. * For example: * {{{ - * df.over.partitionBy("k1").orderBy($"k2", $"k3").range.between + * df.over(partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(1).and.currentRow) * }}} * @group window_funcs */ @@ -139,7 +147,7 @@ class WindowFunctionDefinition protected[sql]( * upper bound. * For example: * {{{ - * df.over.partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(3).and + * df.over(partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(3).and.currentRow) * }}} * @group window_funcs */ @@ -151,7 +159,7 @@ class WindowFunctionDefinition protected[sql]( * Returns a new Ranged [[WindowFunctionDefinition]]. * For example: * {{{ - * df.over.partitionBy("k1").orderBy($"k2", $"k3").range.between + * df.over(partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(3).and.currentRow) * }}} * @group window_funcs */ @@ -164,7 +172,7 @@ class WindowFunctionDefinition protected[sql]( * Returns a new [[WindowFunctionDefinition]], with fixed number of records. * For example: * {{{ - * df.over.partitionBy("k1").orderBy($"k2", $"k3").row + * df.over(partitionBy("k1").orderBy($"k2", $"k3").rows) * }}} * @group window_funcs */ @@ -285,26 +293,7 @@ class WindowFunctionDefinition protected[sql]( } /** - * Convert the window definition into a new Column. - * Currently, only aggregate expressions are supported for window function. For Example: - * {{{ - * df.select( - * avg("value") - * .over - * .partitionBy("k1") - * .orderBy($"k2", $"k3") - * .row - * .following(1) - * .toColumn.as("avg_value"), - * max("value") - * .over - * .partitionBy("k2") - * .orderBy("k3") - * .between - * .preceding(4) - * .following(3) - * .toColumn.as("max_value")) - * }}} + * Convert the window definition into a Column object. 
* @group window_funcs */ private[sql] def toColumn: Column = { From 57e3bc032ace248f2a1083e1db66a9afd9ea2866 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 17 May 2015 22:49:49 +0800 Subject: [PATCH 07/16] typos --- .../spark/sql/WindowFunctionDefinition.scala | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala index ed66ac00afab..f3938e881b41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._ * over( * partitionBy("k1") * .orderBy("k2", "k3") - * .row + * .rows * .following(1)).as("avg_value"), * max("value") * .over( @@ -53,7 +53,13 @@ import org.apache.spark.sql.catalyst.expressions._ * * }}} * - * + * @param column The bounded the aggregate/window function + * @param partitionSpec The partition of the window + * @param orderSpec The ordering of the window + * @param frame The Window Frame type + * @param bindLower A hint of when call the methods `.preceding(n)` `.currentRow()` `.following()` + * if bindLower == true, then we will set the lower bound, otherwise, we should + * set the upper bound for the Row/Range Frame. */ @Experimental class WindowFunctionDefinition protected[sql]( @@ -183,15 +189,15 @@ class WindowFunctionDefinition protected[sql]( /** * Returns a new [[WindowFunctionDefinition]], with position specified preceding of CURRENT_ROW. - * It can be either Lower or Upper Bound position, depends on whether the `and` method called. + * It can be either Lower or Upper Bound position depends on the semantic context. * For example: * {{{ * // [CURRENT_ROW - 1, ~) - * df.over(partitionBy("k1").orderBy("k2").row.preceding(1)) + * df.over(partitionBy("k1").orderBy("k2").rows.preceding(1)) * // [CURRENT_ROW - 3, CURRENT_ROW - 1] - * df.over(partitionBy("k1").orderBy("k2").row.between.preceding(3).and.preceding(1)) + * df.over(partitionBy("k1").orderBy("k2").rows.between.preceding(3).and.preceding(1)) * // (~, CURRENT_ROW - 1] - * df.over(partitionBy("k1").orderBy("k2").row.between.unboundedPreceding.and.preceding(1)) + * df.over(partitionBy("k1").orderBy("k2").rows.between.unboundedPreceding.and.preceding(1)) * }}} * @group window_funcs */ @@ -212,7 +218,7 @@ class WindowFunctionDefinition protected[sql]( * For example: * {{{ * // (~, CURRENT_ROW] - * df.over(partitionBy("k1").orderBy("k2").row.between.unboundedPreceding.and.currentRow) + * df.over(partitionBy("k1").orderBy("k2").rows.between.unboundedPreceding.and.currentRow) * }}} * @group window_funcs */ @@ -230,7 +236,7 @@ class WindowFunctionDefinition protected[sql]( * For example: * {{{ * // [CURRENT_ROW, ~) - * df.over(partitionBy("k1").orderBy("k2").row.between.currentRow.and.unboundedFollowing) + * df.over(partitionBy("k1").orderBy("k2").rows.between.currentRow.and.unboundedFollowing) * }}} * @group window_funcs */ @@ -245,13 +251,13 @@ class WindowFunctionDefinition protected[sql]( /** * Returns a new [[WindowFunctionDefinition]], with position as CURRENT_ROW. - * It can be either Lower or Upper Bound position, depends on whether the `and` method called. + * It can be either Lower or Upper Bound position, depends on the semantic context. 
* For example:
* {{{
* // [CURRENT_ROW, ~)
- * df.over(partitionBy("k1").orderBy("k2").row.between.currentRow.and.unboundedFollowing)
+ * df.over(partitionBy("k1").orderBy("k2").rows.between.currentRow.and.unboundedFollowing)
* // [CURRENT_ROW - 3, CURRENT_ROW]
- * df.over(partitionBy("k1").orderBy("k2").row.between.preceding(3).and.currentRow)
+ * df.over(partitionBy("k1").orderBy("k2").rows.between.preceding(3).and.currentRow)
* }}}
* @group window_funcs
*/
@@ -268,15 +274,15 @@ class WindowFunctionDefinition protected[sql](
/**
* Returns a new [[WindowFunctionDefinition]], with position specified following of CURRENT_ROW.
- * It can be either Lower or Upper Bound position, depends on whether the `and` method called.
+ * It can be either Lower or Upper Bound position, depends on the semantic context.
* For example:
* {{{
* // [CURRENT_ROW + 1, ~)
- * df.over(partitionBy("k1").orderBy("k2").row.following(1))
+ * df.over(partitionBy("k1").orderBy("k2").rows.following(1))
* // [CURRENT_ROW + 1, CURRENT_ROW + 3]
- * df.over(partitionBy("k1").orderBy("k2").row.between.following(1).and.following(3))
+ * df.over(partitionBy("k1").orderBy("k2").rows.between.following(1).and.following(3))
* // [CURRENT_ROW + 1, ~)
- * df.over(partitionBy("k1").orderBy("k2").row.between.following(1).and.unboundedFollowing)
+ * df.over(partitionBy("k1").orderBy("k2").rows.between.following(1).and.unboundedFollowing)
* }}}
* @group window_funcs
*/

From 68478250f67681ccbe2fe70c216b60f96e32e5d8 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Sun, 17 May 2015 22:39:17 -0700
Subject: [PATCH 08/16] Add additional analytics functions

---
.../spark/sql/WindowFunctionDefinition.scala | 4 -
.../org/apache/spark/sql/functions.scala | 82 +++++++++++++
.../sql/hive/HiveDataFrameWindowSuite.scala | 113 ++++++++----------
3 files changed, 132 insertions(+), 67 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala
index f3938e881b41..96a1ad7c6580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala
@@ -333,10 +333,6 @@ class WindowFunctionDefinition protected[sql](
case wf: WindowFunction => WindowExpression(
wf,
WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case aggr: AggregateExpression =>
- throw new UnsupportedOperationException(
- """Only support Aggregate Functions:
- | avg, sum, count, first, last, min, max for now""".stripMargin)
case x =>
throw new UnsupportedOperationException(s"We don't support $x in window operation.")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c21cbd95ed67..79f89e6d483c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -501,6 +501,88 @@ object functions {
new WindowFunctionDefinition().orderBy(cols: _*)
}
+ /**
+ * NTILE for specified expression.
+ * NTILE allows easy calculation of tertiles, quartiles, deciles and other
+ * common summary statistics. This function divides an ordered partition into a specified
+ * number of groups called buckets and assigns a bucket number to each row in the partition.
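+ *
+ * A usage sketch, mirroring the test suite in this patch (`df` has columns
+ * `key` and `value`):
+ * {{{
+ *   df.select(ntile($"key").over(partitionBy("value").orderBy("key")))
+ * }}}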
+ * + * @group window_funcs + */ + def ntile(e: Column): Column = { + UnresolvedWindowFunction("ntile", e.expr :: Nil) + } + + /** + * NTILE for specified column. + * NTILE allows easy calculation of tertiles, quartiles, deciles and other + * common summary statistics. This function divides an ordered partition into a specified + * number of groups called buckets and assigns a bucket number to each row in the partition. + * + * @group window_funcs + */ + def ntile(columnName: String): Column = { + ntile(Column(columnName)) + } + + /** + * Assigns a unique number (sequentially, starting from 1, as defined by ORDER BY) to each + * row within the partition. + * + * @group window_funcs + */ + def rowNumber(): Column = { + UnresolvedWindowFunction("row_number", Nil) + } + + /** + * The difference between RANK and DENSE_RANK is that DENSE_RANK leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using DENSE_RANK + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * @group window_funcs + */ + def denseRank(): Column = { + UnresolvedWindowFunction("dense_rank", Nil) + } + + /** + * The difference between RANK and DENSE_RANK is that DENSE_RANK leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using DENSE_RANK + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * @group window_funcs + */ + def rank(): Column = { + UnresolvedWindowFunction("rank", Nil) + } + + /** + * CUME_DIST (defined as the inverse of percentile in some statistical books) computes + * the position of a specified value relative to a set of values. + * To compute the CUME_DIST of a value x in a set S of size N, you use the formula: + * CUME_DIST(x) = number of values in S coming before and including x in the specified order / N + * + * @group window_funcs + */ + def cumeDist(): Column = { + UnresolvedWindowFunction("cume_dist", Nil) + } + + /** + * PERCENT_RANK is similar to CUME_DIST, but it uses rank values rather than row counts + * in its numerator. 
+ * The formula: + * (rank of row in its partition - 1) / (number of rows in the partition - 1) + * + * @group window_funcs + */ + def percentRank(): Column = { + UnresolvedWindowFunction("percent_rank", Nil) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 2b49f383b2ab..3ba9bef2da9e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{AnalysisException, Row, QueryTest} +import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ class HiveDataFrameWindowSuite extends QueryTest { @@ -59,7 +60,7 @@ class HiveDataFrameWindowSuite extends QueryTest { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - lead("value").over( + lag("value").over( partitionBy($"key") .orderBy($"value"))), Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil) @@ -82,15 +83,56 @@ class HiveDataFrameWindowSuite extends QueryTest { (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") checkAnswer( df.select( - lead("value", 2, "n/a").over( + lag("value", 2, "n/a").over( partitionBy($"key") .orderBy($"value"))), Row("1") :: Row("1") :: Row("2") :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil) } + test("rank functions in unspecific window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + ntile("key").over( + partitionBy("value") + .orderBy("key")), + ntile($"key").over( + partitionBy("value") + .orderBy("key")), + rowNumber().over( + partitionBy("value") + .orderBy("key")), + denseRank().over( + partitionBy("value") + .orderBy("key")), + rank().over( + partitionBy("value") + .orderBy("key")), + cumeDist().over( + partitionBy("value") + .orderBy("key")), + percentRank().over( + partitionBy("value") + .orderBy("key"))), + sql( + s"""SELECT + |key, + |ntile(key) over (partition by value order by key), + |ntile(key) over (partition by value order by key), + |row_number() over (partition by value order by key), + |dense_rank() over (partition by value order by key), + |rank() over (partition by value order by key), + |cume_dist() over (partition by value order by key), + |percent_rank() over (partition by value order by key) + |FROM window_table""".stripMargin).collect) + } + test("aggregation in a row window") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over( @@ -106,6 +148,7 @@ class HiveDataFrameWindowSuite extends QueryTest { test("aggregation in a Range window") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over( @@ -119,68 +162,9 @@ class HiveDataFrameWindowSuite extends QueryTest { Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil) } - test("multiple 
aggregate function in window") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - checkAnswer( - df.select( - avg("key").over( - partitionBy($"value") - .orderBy($"key") - .rows - .preceding(1)), - sum("key").over( - partitionBy($"value") - .orderBy($"key") - .range - .between - .preceding(1) - .and - .following(1))), - Row(1.0, 2) :: Row(1.0, 2) :: Row(2.0, 4) :: Row(2.0, 4) :: Nil) - } - - test("Window function in Unspecified Window") { - val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") - - checkAnswer( - df.select( - $"key", - first("value").over( - partitionBy($"key"))), - Row(1, "1") :: Row(2, "2") :: Row(2, "2") :: Nil) - } - - test("Window function in Unspecified Window #2") { - val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") - - checkAnswer( - df.select( - $"key", - first("value").over( - partitionBy($"key") - .orderBy($"value"))), - Row(1, "1") :: Row(2, "2") :: Row(2, "2") :: Nil) - } - - test("Aggregate function in Range Window") { - val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") - - checkAnswer( - df.select( - $"key", - first("value").over( - partitionBy($"value") - .orderBy($"key") - .range - .between - .preceding(1) - .and - .following(1))), - Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) - } - test("Aggregate function in Row preceding Window") { val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( $"key", @@ -194,6 +178,7 @@ class HiveDataFrameWindowSuite extends QueryTest { test("Aggregate function in Row following Window") { val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( $"key", @@ -207,6 +192,7 @@ class HiveDataFrameWindowSuite extends QueryTest { test("Multiple aggregate functions in row window") { val df = Seq((1, "1"), (1, "2"), (3, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over( @@ -240,6 +226,7 @@ class HiveDataFrameWindowSuite extends QueryTest { test("Multiple aggregate functions in range window") { val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( $"key", From f3fd2d02a3ec7c4135005514a12038747594f0dc Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 17 May 2015 23:26:41 -0700 Subject: [PATCH 09/16] polish the unit test --- .../sql/hive/HiveDataFrameWindowSuite.scala | 148 +++++++++++++++--- 1 file changed, 122 insertions(+), 26 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 3ba9bef2da9e..309aab31b45d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -48,54 +48,85 @@ class HiveDataFrameWindowSuite extends QueryTest { test("lead in window") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( df.select( lead("value").over( partitionBy($"key") .orderBy($"value"))), - Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil) + sql( + """SELECT + | lead(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) } test("lag in window") { val df = Seq((1, "1"), 
(2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( df.select( lag("value").over( partitionBy($"key") .orderBy($"value"))), - Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil) + sql( + """SELECT + | lag(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) } test("lead in window with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( lead("value", 2, "n/a").over( partitionBy("key") .orderBy("value"))), - Row("1") :: Row("1") :: Row("2") :: Row("n/a") - :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil) + sql( + """SELECT + | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) } test("lag in window with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") checkAnswer( df.select( lag("value", 2, "n/a").over( partitionBy($"key") .orderBy($"value"))), - Row("1") :: Row("1") :: Row("2") :: Row("n/a") - :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil) + sql( + """SELECT + | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) } test("rank functions in unspecific window") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", + max("key").over( + partitionBy("value") + .orderBy("key")), + min("key").over( + partitionBy("value") + .orderBy("key")), + mean("key").over( + partitionBy("value") + .orderBy("key")), + count("key").over( + partitionBy("value") + .orderBy("key")), + sum("key").over( + partitionBy("value") + .orderBy("key")), ntile("key").over( partitionBy("value") .orderBy("key")), @@ -120,6 +151,11 @@ class HiveDataFrameWindowSuite extends QueryTest { sql( s"""SELECT |key, + |max(key) over (partition by value order by key), + |min(key) over (partition by value order by key), + |avg(key) over (partition by value order by key), + |count(key) over (partition by value order by key), + |sum(key) over (partition by value order by key), |ntile(key) over (partition by value order by key), |ntile(key) over (partition by value order by key), |row_number() over (partition by value order by key), @@ -143,7 +179,11 @@ class HiveDataFrameWindowSuite extends QueryTest { .preceding(1) .and .following(1))), - Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil) + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 1 following) + | FROM window_table""".stripMargin).collect()) } test("aggregation in a Range window") { @@ -159,11 +199,15 @@ class HiveDataFrameWindowSuite extends QueryTest { .preceding(1) .and .following(1))), - Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil) + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) + | FROM window_table""".stripMargin).collect()) } test("Aggregate function in Row preceding Window") { - val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "3"), (4, "3")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( @@ -172,22 +216,65 @@ class 
HiveDataFrameWindowSuite extends QueryTest { partitionBy($"value") .orderBy($"key") .rows - .preceding(1))), - Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) + .preceding(1)), + first("value").over( + partitionBy($"value") + .orderBy($"key") + .rows + .between + .preceding(2) + .and + .preceding(1))), + sql( + """SELECT + | key, + | first_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS 1 preceding), + | first_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between 2 preceding and 1 preceding) + | FROM window_table""".stripMargin).collect()) } test("Aggregate function in Row following Window") { - val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value") + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", + last("value").over( + partitionBy($"value") + .orderBy($"key") + .rows + .between + .currentRow() + .and + .unboundedFollowing()), + last("value").over( + partitionBy($"value") + .orderBy($"key") + .rows + .between + .unboundedPreceding() + .and + .currentRow()), last("value").over( partitionBy($"value") .orderBy($"key") .rows + .between + .preceding(1) + .and .following(1))), - Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil) + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) + | FROM window_table""".stripMargin).collect()) } test("Multiple aggregate functions in row window") { @@ -216,12 +303,15 @@ class HiveDataFrameWindowSuite extends QueryTest { .preceding(2) .and .preceding(1))), - Row(1.0, 1.0, 1.0) :: - Row(1.0, 1.0, 1.0) :: - Row(1.0, 1.0, 1.0) :: - Row(2.0, 2.0, 2.0) :: - Row(2.0, 2.0, 2.0) :: - Row(3.0, 3.0, 3.0) :: Nil) + sql( + """SELECT + | avg(key) OVER + | (partition by key ORDER BY value rows 1 preceding), + | avg(key) OVER + | (partition by key ORDER BY value rows between current row and current row), + | avg(key) OVER + | (partition by key ORDER BY value rows between 2 preceding and 1 preceding) + | FROM window_table""".stripMargin).collect()) } test("Multiple aggregate functions in range window") { @@ -268,11 +358,17 @@ class HiveDataFrameWindowSuite extends QueryTest { .currentRow) .as("avg_key3") ), - Row(1, false, 1.0, 1.0, 1.0) :: - Row(1, false, 1.0, 1.0, 1.0) :: - Row(2, true, 2.0, 2.0, 2.0) :: - Row(2, true, 2.0, 2.0, 2.0) :: - Row(2, true, 2.0, 2.0, 2.0) :: - Row(2, true, 2.0, 2.0, 2.0) :: Nil) + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) + | FROM window_table""".stripMargin).collect()) } } From 3b1865f3342f4b142383d0ab47fd951d5df05e65 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 17 May 2015 23:28:48 -0700 Subject: [PATCH 10/16] scaladoc typos --- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f096b17106dc..aa404bddd77b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -894,8 +894,8 @@ class Column(protected[sql] val expr: Expression) extends Logging { * {{{ * val w = partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w).range.preceding(2), - * avg("price").over(w).range.preceding(4), + * sum("price").over(w.range.preceding(2)), + * avg("price").over(w.range.preceding(4)), * avg("price").over(partitionBy("name").orderBy("id).range.preceding(1)) * ) * }}} From c141fb1e2ec99aa616770d120abac1b0fbb1fa58 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 18 May 2015 19:58:46 -0700 Subject: [PATCH 11/16] hide all of properties of the WindowFunctionDefinition --- .../spark/sql/WindowFunctionDefinition.scala | 69 ++++++++++++++----- 1 file changed, 50 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala index 96a1ad7c6580..b786841a4470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala @@ -53,21 +53,32 @@ import org.apache.spark.sql.catalyst.expressions._ * * }}} * - * @param column The bounded the aggregate/window function - * @param partitionSpec The partition of the window - * @param orderSpec The ordering of the window - * @param frame The Window Frame type - * @param bindLower A hint of when call the methods `.preceding(n)` `.currentRow()` `.following()` - * if bindLower == true, then we will set the lower bound, otherwise, we should - * set the upper bound for the Row/Range Frame. */ @Experimental -class WindowFunctionDefinition protected[sql]( - column: Column = null, - partitionSpec: Seq[Expression] = Nil, - orderSpec: Seq[SortOrder] = Nil, - frame: WindowFrame = UnspecifiedFrame, - bindLower: Boolean = true) { +class WindowFunctionDefinition { + private var column: Column = _ + private var partitionSpec: Seq[Expression] = Nil + private var orderSpec: Seq[SortOrder] = Nil + private var frame: WindowFrame = UnspecifiedFrame + + // Hint of when call the methods `.preceding(n)` `.currentRow()` `.following()` + // if bindLower == true, then we will set the lower bound, otherwise, we should + // set the upper bound for the Row/Range Frame. + private var bindLower: Boolean = true + + private def this( + column: Column = null, + partitionSpec: Seq[Expression] = Nil, + orderSpec: Seq[SortOrder] = Nil, + frame: WindowFrame = UnspecifiedFrame, + bindLower: Boolean = true) { + this() + this.column = column + this.partitionSpec = partitionSpec + this.orderSpec = orderSpec + this.frame = frame + this.bindLower = bindLower + } private[sql] def newColumn(c: Column): WindowFunctionDefinition = { new WindowFunctionDefinition(c, partitionSpec, orderSpec, frame, bindLower) @@ -192,6 +203,8 @@ class WindowFunctionDefinition protected[sql]( * It can be either Lower or Upper Bound position depends on the semantic context. 
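+ * (As of this patch, `preceding(0)` and `following(0)` are also accepted and
+ * are normalized to CURRENT ROW.)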
* For example:
* {{{
+ * // [CURRENT_ROW, ~)
+ * df.over(partitionBy("k1").orderBy("k2").rows.preceding(0))
* // [CURRENT_ROW - 1, ~)
* df.over(partitionBy("k1").orderBy("k2").rows.preceding(1))
* // [CURRENT_ROW - 3, CURRENT_ROW - 1]
* df.over(partitionBy("k1").orderBy("k2").rows.between.preceding(3).and.preceding(1))
* // (~, CURRENT_ROW - 1]
* df.over(partitionBy("k1").orderBy("k2").rows.between.unboundedPreceding.and.preceding(1))
* }}}
* @group window_funcs
*/
def preceding(n: Int): WindowFunctionDefinition = {
- assert(n > 0)
+ require(n >= 0, s"preceding(n) requires n greater than or equals 0, but got $n")
val newFrame = frame match {
case f: SpecifiedWindowFrame if bindLower =>
- f.copy(frameStart = ValuePreceding(n))
+ if (n > 0) {
+ f.copy(frameStart = ValuePreceding(n))
+ } else {
+ f.copy(frameStart = CurrentRow)
+ }
case f: SpecifiedWindowFrame =>
- f.copy(frameEnd = ValuePreceding(n))
+ if (n > 0) {
+ f.copy(frameEnd = ValuePreceding(n))
+ } else {
+ f.copy(frameEnd = CurrentRow)
+ }
case f => throw new UnsupportedOperationException(s"preceding on $f")
}
new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false)
@@ -277,6 +298,8 @@ class WindowFunctionDefinition protected[sql](
* It can be either Lower or Upper Bound position, depends on the semantic context.
* For example:
* {{{
+ * // [CURRENT_ROW, ~)
+ * df.over(partitionBy("k1").orderBy("k2").rows.following(0))
* // [CURRENT_ROW + 1, ~)
* df.over(partitionBy("k1").orderBy("k2").rows.following(1))
* // [CURRENT_ROW + 1, CURRENT_ROW + 3]
* df.over(partitionBy("k1").orderBy("k2").rows.between.following(1).and.following(3))
* }}}
* @group window_funcs
*/
def following(n: Int): WindowFunctionDefinition = {
- assert(n > 0)
+ require(n >= 0, s"following(n) requires n greater than or equals 0, but got $n")
val newFrame = frame match {
case f: SpecifiedWindowFrame if bindLower =>
- f.copy(frameStart = ValueFollowing(n))
+ if (n > 0) {
+ f.copy(frameStart = ValueFollowing(n))
+ } else {
+ f.copy(frameStart = CurrentRow)
+ }
case f: SpecifiedWindowFrame =>
- f.copy(frameEnd = ValueFollowing(n))
+ if (n > 0) {
+ f.copy(frameEnd = ValueFollowing(n))
+ } else {
+ f.copy(frameEnd = CurrentRow)
+ }
case f => throw new UnsupportedOperationException(s"following on $f")
}
new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false)

From d625a642735e3b741c5d701e2fe28a0b6c8c845f Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Fri, 22 May 2015 01:09:04 +0800
Subject: [PATCH 12/16] Update the dataframe window API as suggested

---
.../scala/org/apache/spark/sql/Column.scala | 6 +-
.../scala/org/apache/spark/sql/Window.scala | 222 +++++++++++
.../spark/sql/WindowFunctionDefinition.scala | 372 ------------------
.../org/apache/spark/sql/functions.scala | 56 +--
.../spark/sql/hive/JavaDataFrameSuite.java | 87 ++++
.../sql/hive/HiveDataFrameWindowSuite.scala | 154 +++-----
6 files changed, 373 insertions(+), 524 deletions(-)
create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/Window.scala
delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala
create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index aa404bddd77b..d9a715fd645f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -890,9 +890,9 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)
/**
- * Define a
[[WindowFunctionDefinition]] column. + * Define a [[Window]] column. * {{{ - * val w = partitionBy("name").orderBy("id") + * val w = Window.partitionBy("name").orderBy("id") * df.select( * sum("price").over(w.range.preceding(2)), * avg("price").over(w.range.preceding(4)), @@ -902,7 +902,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def over(w: WindowFunctionDefinition): Column = w.newColumn(this).toColumn + def over(w: Window): Column = w.newColumn(this).toColumn } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/Window.scala new file mode 100644 index 000000000000..80272f380b11 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Window.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ + + +sealed private[sql] class Frame(private[sql] var boundary: FrameBoundary = null) + +/** + * :: Experimental :: + * An utility to specify the Window Frame Range. + */ +object Frame { + val currentRow: Frame = new Frame(CurrentRow) + val unbounded: Frame = new Frame() + def preceding(n: Int): Frame = if (n == 0) { + new Frame(CurrentRow) + } else { + new Frame(ValuePreceding(n)) + } + + def following(n: Int): Frame = if (n == 0) { + new Frame(CurrentRow) + } else { + new Frame(ValueFollowing(n)) + } +} + +/** + * :: Experimental :: + * A Window object with everything unset. But can build new Window object + * based on it. + */ +@Experimental +object Window extends Window() + +/** + * :: Experimental :: + * A set of methods for window function definition for aggregate expressions. 
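+ * A [[Window]] is built fluently from `partitionBy`/`orderBy` plus an optional
+ * `rowsBetween`/`rangeBetween` frame, and is then attached to an aggregate
+ * column via [[Column.over]].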
+ * For example: + * {{{ + * // predefine a window + * val w = Window.partitionBy("name").orderBy("id") + * .rowsBetween(Frame.unbounded, Frame.currentRow) + * df.select( + * avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..") + * .rowsBetween(Frame.unbounded, Frame.currentRow)) + * ) + * + * df.select( + * avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..") + * .rowsBetween(Frame.preceding(50), Frame.following(10))) + * ) + * + * }}} + * + */ +@Experimental +class Window { + private var column: Column = _ + private var partitionSpec: Seq[Expression] = Nil + private var orderSpec: Seq[SortOrder] = Nil + private var frame: WindowFrame = UnspecifiedFrame + + private def this( + column: Column = null, + partitionSpec: Seq[Expression] = Nil, + orderSpec: Seq[SortOrder] = Nil, + frame: WindowFrame = UnspecifiedFrame) { + this() + this.column = column + this.partitionSpec = partitionSpec + this.orderSpec = orderSpec + this.frame = frame + } + + private[sql] def newColumn(c: Column): Window = { + new Window(c, partitionSpec, orderSpec, frame) + } + + /** + * Returns a new [[Window]] partitioned by the specified column. + * {{{ + * // The following 2 are equivalent + * df.over(Window.partitionBy("k1", "k2", ...)) + * df.over(Window.partitionBy($"K1", $"k2", ...)) + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): Window = { + partitionBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Returns a new [[Window]] partitioned by the specified column. For example: + * {{{ + * df.over(Window.partitionBy($"col1", $"col2")) + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): Window = { + new Window(column, cols.map(_.expr), orderSpec, frame) + } + + /** + * Returns a new [[Window]] sorted by the specified column within + * the partition. + * {{{ + * // The following 2 are equivalent + * df.over(Window.partitionBy("k1").orderBy("k2", "k3")) + * df.over(Window.partitionBy("k1").orderBy($"k2", $"k3")) + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): Window = { + orderBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Returns a new [[Window]] sorted by the specified column within + * the partition. 
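Columns passed without an explicit sort direction default to ascending order.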
For example + * {{{ + * df.over(Window.partitionBy("k1").orderBy($"k2", $"k3")) + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def orderBy(cols: Column*): Window = { + val sortOrder: Seq[SortOrder] = cols.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + new Window(column, partitionSpec, sortOrder, frame) + } + + def rowsBetween(start: Frame, end: Frame): Window = { + assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing") + assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding") + + val s = if (start.boundary == null) UnboundedPreceding else start.boundary + val e = if (end.boundary == null) UnboundedFollowing else end.boundary + + new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RowFrame, s, e)) + } + + def rangeBetween(start: Frame, end: Frame): Window = { + assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing") + assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding") + + val s = if (start.boundary == null) UnboundedPreceding else start.boundary + val e = if (end.boundary == null) UnboundedFollowing else end.boundary + + new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RangeFrame, s, e)) + } + + /** + * Convert the window definition into a Column object. + * @group window_funcs + */ + private[sql] def toColumn: Column = { + if (column == null) { + throw new AnalysisException("Window didn't bind with expression") + } + val windowExpr = column.expr match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction("first_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction("last_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case wf: WindowFunction => WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"We don't support $x in window operation.") + } + new Column(windowExpr) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala deleted file mode 100644 index b786841a4470..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.language.implicitConversions - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.expressions._ - -/** - * :: Experimental :: - * A set of methods for window function definition for aggregate expressions. - * For example: - * {{{ - * // predefine a window - * val w = partitionBy("name").orderBy("id") - * - * df.select( - * first("value") - * over(w).as("first_value"), - * last("value") - * over(w).as("last_value"), - * avg("value") - * over( - * partitionBy("k1") - * .orderBy("k2", "k3") - * .rows - * .following(1)).as("avg_value"), - * max("value") - * .over( - * partitionBy("k2") - * .orderBy("k3") - * .range - * .between - * .preceding(4) - * .and - * .following(3)).as("max_value")) - * - * }}} - * - */ -@Experimental -class WindowFunctionDefinition { - private var column: Column = _ - private var partitionSpec: Seq[Expression] = Nil - private var orderSpec: Seq[SortOrder] = Nil - private var frame: WindowFrame = UnspecifiedFrame - - // Hint of when call the methods `.preceding(n)` `.currentRow()` `.following()` - // if bindLower == true, then we will set the lower bound, otherwise, we should - // set the upper bound for the Row/Range Frame. - private var bindLower: Boolean = true - - private def this( - column: Column = null, - partitionSpec: Seq[Expression] = Nil, - orderSpec: Seq[SortOrder] = Nil, - frame: WindowFrame = UnspecifiedFrame, - bindLower: Boolean = true) { - this() - this.column = column - this.partitionSpec = partitionSpec - this.orderSpec = orderSpec - this.frame = frame - this.bindLower = bindLower - } - - private[sql] def newColumn(c: Column): WindowFunctionDefinition = { - new WindowFunctionDefinition(c, partitionSpec, orderSpec, frame, bindLower) - } - - /** - * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. - * {{{ - * // The following 2 are equivalent - * df.over(partitionBy("k1", "k2", ...)) - * df.over(partitionBy($"K1", $"k2", ...)) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(colName: String, colNames: String*): WindowFunctionDefinition = { - partitionBy((colName +: colNames).map(Column(_)): _*) - } - - /** - * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. For example: - * {{{ - * df.over(partitionBy($"col1", $"col2")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(cols: Column*): WindowFunctionDefinition = { - new WindowFunctionDefinition(column, cols.map(_.expr), orderSpec, frame) - } - - /** - * Returns a new [[WindowFunctionDefinition]] sorted by the specified column within - * the partition. 
- * {{{ - * // The following 2 are equivalent - * df.over(partitionBy("k1").orderBy("k2", "k3")) - * df.over(partitionBy("k1").orderBy($"k2", $"k3")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def orderBy(colName: String, colNames: String*): WindowFunctionDefinition = { - orderBy((colName +: colNames).map(Column(_)): _*) - } - - /** - * Returns a new [[WindowFunctionDefinition]] sorted by the specified column within - * the partition. For example - * {{{ - * df.over(partitionBy("k1").orderBy($"k2", $"k3")) - * }}} - * @group window_funcs - */ - def orderBy(cols: Column*): WindowFunctionDefinition = { - val sortOrder: Seq[SortOrder] = cols.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - new WindowFunctionDefinition(column, partitionSpec, sortOrder, frame) - } - - /** - * Returns the current [[WindowFunctionDefinition]]. This is a dummy function, - * which makes the usage more like the SQL. - * For example: - * {{{ - * df.over(partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(1).and.currentRow) - * }}} - * @group window_funcs - */ - def between: WindowFunctionDefinition = { - assert(this.frame.isInstanceOf[SpecifiedWindowFrame], "Should be a WindowFrame.") - new WindowFunctionDefinition(column, partitionSpec, orderSpec, frame, true) - } - - /** - * Returns a new [[WindowFunctionDefinition]] indicate that we need to specify the - * upper bound. - * For example: - * {{{ - * df.over(partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(3).and.currentRow) - * }}} - * @group window_funcs - */ - def and: WindowFunctionDefinition = { - new WindowFunctionDefinition(column, partitionSpec, orderSpec, frame, false) - } - - /** - * Returns a new Ranged [[WindowFunctionDefinition]]. - * For example: - * {{{ - * df.over(partitionBy("k1").orderBy($"k2", $"k3").range.between.preceding(3).and.currentRow) - * }}} - * @group window_funcs - */ - def range: WindowFunctionDefinition = { - new WindowFunctionDefinition(column, partitionSpec, orderSpec, - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, UnboundedFollowing)) - } - - /** - * Returns a new [[WindowFunctionDefinition]], with fixed number of records. - * For example: - * {{{ - * df.over(partitionBy("k1").orderBy($"k2", $"k3").rows) - * }}} - * @group window_funcs - */ - def rows: WindowFunctionDefinition = { - new WindowFunctionDefinition(column, partitionSpec, orderSpec, - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)) - } - - /** - * Returns a new [[WindowFunctionDefinition]], with position specified preceding of CURRENT_ROW. - * It can be either Lower or Upper Bound position depends on the semantic context. 
- * For example: - * {{{ - * // [CURRENT_ROW, ~) - * df.over(partitionBy("k1").orderBy("k2").rows.preceding(0)) - * // [CURRENT_ROW - 1, ~) - * df.over(partitionBy("k1").orderBy("k2").rows.preceding(1)) - * // [CURRENT_ROW - 3, CURRENT_ROW - 1] - * df.over(partitionBy("k1").orderBy("k2").rows.between.preceding(3).and.preceding(1)) - * // (~, CURRENT_ROW - 1] - * df.over(partitionBy("k1").orderBy("k2").rows.between.unboundedPreceding.and.preceding(1)) - * }}} - * @group window_funcs - */ - def preceding(n: Int): WindowFunctionDefinition = { - require(n >= 0, s"preceding(n) requires n greater than or equals 0, but got $n") - val newFrame = frame match { - case f: SpecifiedWindowFrame if bindLower => - if (n > 0) { - f.copy(frameStart = ValuePreceding(n)) - } else { - f.copy(frameStart = CurrentRow) - } - case f: SpecifiedWindowFrame => - if (n > 0) { - f.copy(frameEnd = ValuePreceding(n)) - } else { - f.copy(frameEnd = CurrentRow) - } - case f => throw new UnsupportedOperationException(s"preceding on $f") - } - new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) - } - - /** - * Returns a new [[WindowFunctionDefinition]], with lower position as unbounded. - * For example: - * {{{ - * // (~, CURRENT_ROW] - * df.over(partitionBy("k1").orderBy("k2").rows.between.unboundedPreceding.and.currentRow) - * }}} - * @group window_funcs - */ - def unboundedPreceding(): WindowFunctionDefinition = { - val newFrame = frame match { - case f : SpecifiedWindowFrame => - f.copy(frameStart = UnboundedPreceding) - case f => throw new UnsupportedOperationException(s"unboundedPreceding on $f") - } - new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) - } - - /** - * Returns a new [[WindowFunctionDefinition]], with upper position as unbounded. - * For example: - * {{{ - * // [CURRENT_ROW, ~) - * df.over(partitionBy("k1").orderBy("k2").rows.between.currentRow.and.unboundedFollowing) - * }}} - * @group window_funcs - */ - def unboundedFollowing(): WindowFunctionDefinition = { - val newFrame = frame match { - case f : SpecifiedWindowFrame => - f.copy(frameEnd = UnboundedFollowing) - case f => throw new UnsupportedOperationException(s"unboundedFollowing on $f") - } - new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) - } - - /** - * Returns a new [[WindowFunctionDefinition]], with position as CURRENT_ROW. - * It can be either Lower or Upper Bound position, depends on the semantic context. - * For example: - * {{{ - * // [CURRENT_ROW, ~) - * df.over(partitionBy("k1").orderBy("k2").rows.between.currentRow.and.unboundedFollowing) - * // [CURRENT_ROW - 3, CURRENT_ROW] - * df.over(partitionBy("k1").orderBy("k2").rows.between.preceding(3).and.currentRow) - * }}} - * @group window_funcs - */ - def currentRow(): WindowFunctionDefinition = { - val newFrame = frame match { - case f : SpecifiedWindowFrame if bindLower => - f.copy(frameStart = CurrentRow) - case f : SpecifiedWindowFrame => - f.copy(frameEnd = CurrentRow) - case f => throw new UnsupportedOperationException(s"currentRow on $f") - } - new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) - } - - /** - * Returns a new [[WindowFunctionDefinition]], with position specified following of CURRENT_ROW. - * It can be either Lower or Upper Bound position, depends on the semantic context. 
- * For example: - * {{{ - * // [CURRENT_ROW, ~) - * df.over(partitionBy("k1").orderBy("k2").rows.following(0)) - * // [CURRENT_ROW + 1, ~) - * df.over(partitionBy("k1").orderBy("k2").rows.following(1)) - * // [CURRENT_ROW + 1, CURRENT_ROW + 3] - * df.over(partitionBy("k1").orderBy("k2").rows.between.following(1).and.following(3)) - * // [CURRENT_ROW + 1, ~) - * df.over(partitionBy("k1").orderBy("k2").rows.between.following(1).and.unboundedFollowing) - * }}} - * @group window_funcs - */ - def following(n: Int): WindowFunctionDefinition = { - require(n >= 0, s"following(n) requires n greater than or equals 0, but got $n") - val newFrame = frame match { - case f: SpecifiedWindowFrame if bindLower => - if (n > 0) { - f.copy(frameStart = ValueFollowing(n)) - } else { - f.copy(frameStart = CurrentRow) - } - case f: SpecifiedWindowFrame => - if (n > 0) { - f.copy(frameEnd = ValueFollowing(n)) - } else { - f.copy(frameEnd = CurrentRow) - } - case f => throw new UnsupportedOperationException(s"following on $f") - } - new WindowFunctionDefinition(column, partitionSpec, orderSpec, newFrame, false) - } - - /** - * Convert the window definition into a Column object. - * @group window_funcs - */ - private[sql] def toColumn: Column = { - if (column == null) { - throw new AnalysisException("Window didn't bind with expression") - } - val windowExpr = column.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction("first_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction("last_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"We don't support $x in window operation.") - } - new Column(windowExpr) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 79f89e6d483c..9188d9719242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -37,6 +37,7 @@ import org.apache.spark.util.Utils * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions + * @groupname window_funcs Window functions * @groupname Ungrouped Support functions for DataFrames. 
* @since 1.3.0 */ @@ -446,61 +447,6 @@ object functions { UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil) } - /** - * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. - * For example: - * {{{ - * // The following 2 are equivalent - * partitionBy("k1", "k2").orderBy("k3") - * partitionBy($"K1", $"k2").orderBy($"k3") - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(colName: String, colNames: String*): WindowFunctionDefinition = { - new WindowFunctionDefinition().partitionBy(colName, colNames: _*) - } - - /** - * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column. - * For example: - * {{{ - * partitionBy($"col1", $"col2").orderBy("value") - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(cols: Column*): WindowFunctionDefinition = { - new WindowFunctionDefinition().partitionBy(cols: _*) - } - - /** - * Create a new [[WindowFunctionDefinition]] sorted by the specified columns. - * For example: - * {{{ - * // The following 2 are equivalent - * orderBy("k2", "k3").partitionBy("k1") - * orderBy($"k2", $"k3").partitionBy("k1") - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def orderBy(colName: String, colNames: String*): WindowFunctionDefinition = { - new WindowFunctionDefinition().orderBy(colName, colNames: _*) - } - - /** - * Returns a new [[WindowFunctionDefinition]] sorted by the specified columns. - * For example - * {{{ - * val w = orderBy($"k2", $"k3").partitionBy("k1") - * }}} - * @group window_funcs - */ - def orderBy(cols: Column*): WindowFunctionDefinition = { - new WindowFunctionDefinition().orderBy(cols: _*) - } - /** * NTILE for specified expression. * NTILE allows easy calculation of tertiles, quartiles, deciles and other diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java new file mode 100644 index 000000000000..9d0feb4e5bd3 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.hive; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class JavaDataFrameSuite { + private transient JavaSparkContext sc; + private transient HiveContext hc; + + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + hc = TestHive$.MODULE$; + sc = new JavaSparkContext(hc.sparkContext()); + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); + } + df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df.registerTempTable("window_table"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + hc.sql("DROP TABLE IF EXISTS window_table"); + } + + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select( + functions.avg("key").over( + Window$.MODULE$.partitionBy("value") + .orderBy("key") + .rowsBetween(Frame.preceding(1), Frame.following(1)))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 309aab31b45d..29661cb8a508 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{Frame, Window, Row, QueryTest} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -26,7 +26,7 @@ class HiveDataFrameWindowSuite extends QueryTest { test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = partitionBy("key").orderBy("value") + val w = Window.partitionBy("key").orderBy("value") checkAnswer( df.select( @@ -37,7 +37,7 @@ class HiveDataFrameWindowSuite extends QueryTest { test("reuse window orderBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = orderBy("value").partitionBy("key") + val w = Window.orderBy("value").partitionBy("key") checkAnswer( df.select( @@ -53,7 +53,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( lead("value").over( - partitionBy($"key") + Window.partitionBy($"key") .orderBy($"value"))), sql( """SELECT @@ -68,7 +68,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( 
df.select( lag("value").over( - partitionBy($"key") + Window.partitionBy($"key") .orderBy($"value"))), sql( """SELECT @@ -76,6 +76,19 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } + test("last in window with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + last("value").over(Window)), + sql( + """SELECT + | last_value(value) OVER () + | FROM window_table""".stripMargin).collect()) + } + test("lead in window with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -83,7 +96,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( lead("value", 2, "n/a").over( - partitionBy("key") + Window.partitionBy("key") .orderBy("value"))), sql( """SELECT @@ -98,7 +111,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( lag("value", 2, "n/a").over( - partitionBy($"key") + Window.partitionBy($"key") .orderBy($"value"))), sql( """SELECT @@ -113,40 +126,40 @@ class HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", max("key").over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), min("key").over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), mean("key").over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), count("key").over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), sum("key").over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), ntile("key").over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), ntile($"key").over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), rowNumber().over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), denseRank().over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), rank().over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), cumeDist().over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key")), percentRank().over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key"))), sql( s"""SELECT @@ -172,13 +185,9 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( avg("key").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .rows - .between - .preceding(1) - .and - .following(1))), + .rowsBetween(Frame.preceding(1), Frame.following(1)))), sql( """SELECT | avg(key) OVER @@ -192,13 +201,9 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( avg("key").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .range - .between - .preceding(1) - .and - .following(1))), + .rangeBetween(Frame.preceding(1), Frame.following(1)))), sql( """SELECT | avg(key) OVER @@ -213,18 +218,13 @@ class HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", first("value").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .rows - .preceding(1)), + .rowsBetween(Frame.preceding(1), Frame.currentRow)), first("value").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .rows - .between - .preceding(2) - .and - .preceding(1))), + .rowsBetween(Frame.preceding(2), Frame.preceding(1)))), sql( """SELECT | key, @@ -242,29 +242,17 @@ class 
HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", last("value").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .rows - .between - .currentRow() - .and - .unboundedFollowing()), + .rowsBetween(Frame.currentRow, Frame.unbounded)), last("value").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .rows - .between - .unboundedPreceding() - .and - .currentRow()), + .rowsBetween(Frame.unbounded, Frame.currentRow)), last("value").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .rows - .between - .preceding(1) - .and - .following(1))), + .rowsBetween(Frame.preceding(1), Frame.following(1)))), sql( """SELECT | key, @@ -283,26 +271,17 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( avg("key").over( - partitionBy($"key") + Window.partitionBy($"key") .orderBy($"value") - .rows - .preceding(1)), + .rowsBetween(Frame.preceding(1), Frame.currentRow)), avg("key").over( - partitionBy($"key") + Window.partitionBy($"key") .orderBy($"value") - .rows - .between - .currentRow - .and - .currentRow), + .rowsBetween(Frame.currentRow, Frame.currentRow)), avg("key").over( - partitionBy($"key") + Window.partitionBy($"key") .orderBy($"value") - .rows - .between - .preceding(2) - .and - .preceding(1))), + .rowsBetween(Frame.preceding(2), Frame.preceding(1)))), sql( """SELECT | avg(key) OVER @@ -321,41 +300,28 @@ class HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", last("value").over( - partitionBy($"value") + Window.partitionBy($"value") .orderBy($"key") - .range - .following(1)) + .rangeBetween(Frame.following(1), Frame.unbounded)) .equalTo("2") .as("last_v"), avg("key") .over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key") - .range - .between - .preceding(2) - .and - .following(1)) + .rangeBetween(Frame.preceding(2), Frame.following(1))) .as("avg_key1"), avg("key") .over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key") - .range - .between - .currentRow - .and - .following(1)) + .rangeBetween(Frame.currentRow, Frame.following(1))) .as("avg_key2"), avg("key") .over( - partitionBy("value") + Window.partitionBy("value") .orderBy("key") - .range - .between - .preceding(1) - .and - .currentRow) + .rangeBetween(Frame.preceding(1), Frame.currentRow)) .as("avg_key3") ), sql( From 9331605f82cc370e215bfc93f7f41be7760bc324 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 21 May 2015 19:09:46 -0700 Subject: [PATCH 13/16] Refactored API. 
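The chained builder from the original patch (`over.partitionBy(...).orderBy(...).rows.between.preceding(n).and.following(m)`) and the `Frame` boundary objects are replaced by a small immutable API in `org.apache.spark.sql.expressions`: a `Window` factory object plus a `WindowSpec` that carries the partitioning, ordering, and frame. Frame boundaries become plain `Long` offsets relative to the current row, with `Long.MinValue` and `Long.MaxValue` standing in for the unbounded boundaries. A sketch of how the refactored API reads, assuming a DataFrame `df` with `key` and `value` columns as in the test suites below:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // One reusable, immutable spec: PARTITION BY value ORDER BY key.
    val w = Window.partitionBy("value").orderBy("key")

    df.select(
      // ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
      avg("key").over(w.rowsBetween(-1, 1)),
      // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
      sum("key").over(w.rangeBetween(Long.MinValue, 0)))

Each call on a `WindowSpec` returns a new spec, so a partially built window can be shared across several aggregates.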
--- .../scala/org/apache/spark/sql/Column.scala | 15 +- .../org/apache/spark/sql/DataFrame.scala | 9 +- .../scala/org/apache/spark/sql/Window.scala | 222 ------------------ .../apache/spark/sql/expressions/Window.scala | 81 +++++++ .../spark/sql/expressions/WindowSpec.scala | 175 ++++++++++++++ .../spark/sql/hive/JavaDataFrameSuite.java | 96 ++++---- .../hive/JavaMetastoreDataSourcesSuite.java | 1 + .../sql/hive/HiveDataFrameWindowSuite.scala | 146 +++--------- 8 files changed, 344 insertions(+), 401 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/Window.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d9a715fd645f..6895aa101095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -890,19 +890,20 @@ class Column(protected[sql] val expr: Expression) extends Logging { def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) /** - * Define a [[Window]] column. + * Define a windowing column. 
+ * * {{{ * val w = Window.partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w.range.preceding(2)), - * avg("price").over(w.range.preceding(4)), - * avg("price").over(partitionBy("name").orderBy("id).range.preceding(1)) + * sum("price").over(w.rangeBetween(Long.MinValue, 2)), + * avg("price").over(w.rowsBetween(0, 4)) * ) * }}} * * @group expr_ops + * @since 1.4.0 */ - def over(w: Window): Column = w.newColumn(this).toColumn + def over(window: expressions.WindowSpec): Column = window.withAggregate(this) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d78b4c2f8909..3ec1c4a2f102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, Unresol import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.json.JacksonGenerator import org.apache.spark.sql.sources.CreateTableUsingAsSelect @@ -411,7 +411,7 @@ class DataFrame private[sql]( joined.left, joined.right, joinType = Inner, - Some(expressions.EqualTo( + Some(catalyst.expressions.EqualTo( joined.left.resolve(usingColumn), joined.right.resolve(usingColumn)))) ) @@ -480,8 +480,9 @@ class DataFrame private[sql]( // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. val cond = plan.condition.map { _.transform { - case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => - expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) }} plan.copy(condition = cond) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/Window.scala deleted file mode 100644 index 80272f380b11..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/Window.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import scala.language.implicitConversions - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.expressions._ - - -sealed private[sql] class Frame(private[sql] var boundary: FrameBoundary = null) - -/** - * :: Experimental :: - * An utility to specify the Window Frame Range. - */ -object Frame { - val currentRow: Frame = new Frame(CurrentRow) - val unbounded: Frame = new Frame() - def preceding(n: Int): Frame = if (n == 0) { - new Frame(CurrentRow) - } else { - new Frame(ValuePreceding(n)) - } - - def following(n: Int): Frame = if (n == 0) { - new Frame(CurrentRow) - } else { - new Frame(ValueFollowing(n)) - } -} - -/** - * :: Experimental :: - * A Window object with everything unset. But can build new Window object - * based on it. - */ -@Experimental -object Window extends Window() - -/** - * :: Experimental :: - * A set of methods for window function definition for aggregate expressions. - * For example: - * {{{ - * // predefine a window - * val w = Window.partitionBy("name").orderBy("id") - * .rowsBetween(Frame.unbounded, Frame.currentRow) - * df.select( - * avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..") - * .rowsBetween(Frame.unbounded, Frame.currentRow)) - * ) - * - * df.select( - * avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..") - * .rowsBetween(Frame.preceding(50), Frame.following(10))) - * ) - * - * }}} - * - */ -@Experimental -class Window { - private var column: Column = _ - private var partitionSpec: Seq[Expression] = Nil - private var orderSpec: Seq[SortOrder] = Nil - private var frame: WindowFrame = UnspecifiedFrame - - private def this( - column: Column = null, - partitionSpec: Seq[Expression] = Nil, - orderSpec: Seq[SortOrder] = Nil, - frame: WindowFrame = UnspecifiedFrame) { - this() - this.column = column - this.partitionSpec = partitionSpec - this.orderSpec = orderSpec - this.frame = frame - } - - private[sql] def newColumn(c: Column): Window = { - new Window(c, partitionSpec, orderSpec, frame) - } - - /** - * Returns a new [[Window]] partitioned by the specified column. - * {{{ - * // The following 2 are equivalent - * df.over(Window.partitionBy("k1", "k2", ...)) - * df.over(Window.partitionBy($"K1", $"k2", ...)) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(colName: String, colNames: String*): Window = { - partitionBy((colName +: colNames).map(Column(_)): _*) - } - - /** - * Returns a new [[Window]] partitioned by the specified column. For example: - * {{{ - * df.over(Window.partitionBy($"col1", $"col2")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(cols: Column*): Window = { - new Window(column, cols.map(_.expr), orderSpec, frame) - } - - /** - * Returns a new [[Window]] sorted by the specified column within - * the partition. - * {{{ - * // The following 2 are equivalent - * df.over(Window.partitionBy("k1").orderBy("k2", "k3")) - * df.over(Window.partitionBy("k1").orderBy($"k2", $"k3")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def orderBy(colName: String, colNames: String*): Window = { - orderBy((colName +: colNames).map(Column(_)): _*) - } - - /** - * Returns a new [[Window]] sorted by the specified column within - * the partition. 
For example - * {{{ - * df.over(Window.partitionBy("k1").orderBy($"k2", $"k3")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def orderBy(cols: Column*): Window = { - val sortOrder: Seq[SortOrder] = cols.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - new Window(column, partitionSpec, sortOrder, frame) - } - - def rowsBetween(start: Frame, end: Frame): Window = { - assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing") - assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding") - - val s = if (start.boundary == null) UnboundedPreceding else start.boundary - val e = if (end.boundary == null) UnboundedFollowing else end.boundary - - new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RowFrame, s, e)) - } - - def rangeBetween(start: Frame, end: Frame): Window = { - assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing") - assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding") - - val s = if (start.boundary == null) UnboundedPreceding else start.boundary - val e = if (end.boundary == null) UnboundedFollowing else end.boundary - - new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RangeFrame, s, e)) - } - - /** - * Convert the window definition into a Column object. - * @group window_funcs - */ - private[sql] def toColumn: Column = { - if (column == null) { - throw new AnalysisException("Window didn't bind with expression") - } - val windowExpr = column.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction("first_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction("last_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"We don't support $x in window operation.") - } - new Column(windowExpr) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala new file mode 100644 index 000000000000..d4003b2d9cbf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +object Window { + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + spec.partitionBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + spec.partitionBy(cols : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + spec.orderBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + spec.orderBy(cols : _*) + } + + private def spec: WindowSpec = { + new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala new file mode 100644 index 000000000000..00ecdb47ca5a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.{Column, catalyst}
+import org.apache.spark.sql.catalyst.expressions._
+
+
+/**
+ * :: Experimental ::
+ * A window specification that defines the partitioning, ordering, and frame boundaries.
+ *
+ * Use the static methods in [[Window]] to create a [[WindowSpec]].
+ *
+ * @since 1.4.0
+ */
+@Experimental
+class WindowSpec private[sql](
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    frame: catalyst.expressions.WindowFrame) {
+
+  /**
+   * Defines the partitioning columns in a [[WindowSpec]].
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def partitionBy(colName: String, colNames: String*): WindowSpec = {
+    partitionBy((colName +: colNames).map(Column(_)): _*)
+  }
+
+  /**
+   * Defines the partitioning columns in a [[WindowSpec]].
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def partitionBy(cols: Column*): WindowSpec = {
+    new WindowSpec(cols.map(_.expr), orderSpec, frame)
+  }
+
+  /**
+   * Defines the ordering columns in a [[WindowSpec]].
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def orderBy(colName: String, colNames: String*): WindowSpec = {
+    orderBy((colName +: colNames).map(Column(_)): _*)
+  }
+
+  /**
+   * Defines the ordering columns in a [[WindowSpec]].
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def orderBy(cols: Column*): WindowSpec = {
+    val sortOrder: Seq[SortOrder] = cols.map { col =>
+      col.expr match {
+        case expr: SortOrder =>
+          expr
+        case expr: Expression =>
+          SortOrder(expr, Ascending)
+      }
+    }
+    new WindowSpec(partitionSpec, sortOrder, frame)
+  }
+
+  /**
+   * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
+   *
+   * Both `start` and `end` are relative positions from the current row. For example, "0" means
+   * "current row", while "-1" means the row before the current row, and "5" means the fifth row
+   * after the current row.
+   *
+   * @param start boundary start, inclusive.
+   *              The frame is unbounded if this is the minimum long value.
+   * @param end boundary end, inclusive.
+   *            The frame is unbounded if this is the maximum long value.
+   * @since 1.4.0
+   */
+  def rowsBetween(start: Long, end: Long): WindowSpec = {
+    between(RowFrame, start, end)
+  }
+
+  /**
+   * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
+   *
+   * Both `start` and `end` are relative to the current row. For example, "0" means
+   * "current row", while "-1" means one before the current row, and "5" means five
+   * after the current row.
+   *
+   * @param start boundary start, inclusive.
+   *              The frame is unbounded if this is the minimum long value.
+   * @param end boundary end, inclusive.
+   *            The frame is unbounded if this is the maximum long value.
+ * @since 1.4.0 + */ + def rangeBetween(start: Long, end: Long): WindowSpec = { + between(RangeFrame, start, end) + } + + private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if x < 0 => ValuePreceding(-start.toInt) + case x if x > 0 => ValueFollowing(start.toInt) + } + + val boundaryEnd = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedFollowing + case x if x < 0 => ValuePreceding(-start.toInt) + case x if x > 0 => ValueFollowing(start.toInt) + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + } + + /** + * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. + */ + private[sql] def withAggregate(aggregate: Column): Column = { + val windowExpr = aggregate.expr match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction("first_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction("last_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case wf: WindowFunction => WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in window operation.") + } + new Column(windowExpr) + } + +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 9d0feb4e5bd3..eeb676d3dc12 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -14,74 +14,64 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.sql.hive; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; -import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.util.Utils; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.hive.test.TestHive$; public class JavaDataFrameSuite { - private transient JavaSparkContext sc; - private transient HiveContext hc; + private transient JavaSparkContext sc; + private transient HiveContext hc; - DataFrame df; + DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { - String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); - if (errorMessage != null) { - Assert.fail(errorMessage); - } + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); } + } - @Before - public void setUp() throws IOException { - hc = TestHive$.MODULE$; - sc = new JavaSparkContext(hc.sparkContext()); + @Before + public void setUp() throws IOException { + hc = TestHive$.MODULE$; + sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); - for (int i = 0; i < 10; i++) { - jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); - } - df = hc.jsonRDD(sc.parallelize(jsonObjects)); - df.registerTempTable("window_table"); + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } + df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df.registerTempTable("window_table"); + } - @After - public void tearDown() throws IOException { - // Clean up tables. - hc.sql("DROP TABLE IF EXISTS window_table"); - } + @After + public void tearDown() throws IOException { + // Clean up tables. 
+ hc.sql("DROP TABLE IF EXISTS window_table"); + } - @Test - public void saveTableAndQueryIt() { - checkAnswer( - df.select( - functions.avg("key").over( - Window$.MODULE$.partitionBy("value") - .orderBy("key") - .rowsBetween(Frame.preceding(1), Frame.following(1)))), - hc.sql("SELECT avg(key) " + - "OVER (PARTITION BY value " + - " ORDER BY key " + - " ROWS BETWEEN 1 preceding and 1 following) " + - "FROM window_table").collectAsList()); - } + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select(functions.avg("key").over( + Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 58fe96adab17..ee21caf63667 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.hive; import java.io.File; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 29661cb8a508..6fee3bcb1735 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Frame, Window, Row, QueryTest} +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -52,9 +53,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( - lead("value").over( - Window.partitionBy($"key") - .orderBy($"value"))), + lead("value").over(Window.partitionBy($"key").orderBy($"value"))), sql( """SELECT | lead(value) OVER (PARTITION BY key ORDER BY value) @@ -76,28 +75,13 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } - test("last in window with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - last("value").over(Window)), - sql( - """SELECT - | last_value(value) OVER () - | FROM window_table""".stripMargin).collect()) - } - test("lead in window with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( - lead("value", 2, "n/a").over( - Window.partitionBy("key") - .orderBy("value"))), + lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), sql( """SELECT | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) @@ -110,9 +94,7 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - lag("value", 2, "n/a").over( - Window.partitionBy($"key") - 
.orderBy($"value"))), + lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), sql( """SELECT | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) @@ -125,42 +107,18 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - max("key").over( - Window.partitionBy("value") - .orderBy("key")), - min("key").over( - Window.partitionBy("value") - .orderBy("key")), - mean("key").over( - Window.partitionBy("value") - .orderBy("key")), - count("key").over( - Window.partitionBy("value") - .orderBy("key")), - sum("key").over( - Window.partitionBy("value") - .orderBy("key")), - ntile("key").over( - Window.partitionBy("value") - .orderBy("key")), - ntile($"key").over( - Window.partitionBy("value") - .orderBy("key")), - rowNumber().over( - Window.partitionBy("value") - .orderBy("key")), - denseRank().over( - Window.partitionBy("value") - .orderBy("key")), - rank().over( - Window.partitionBy("value") - .orderBy("key")), - cumeDist().over( - Window.partitionBy("value") - .orderBy("key")), - percentRank().over( - Window.partitionBy("value") - .orderBy("key"))), + max("key").over(Window.partitionBy("value").orderBy("key")), + min("key").over(Window.partitionBy("value").orderBy("key")), + mean("key").over(Window.partitionBy("value").orderBy("key")), + count("key").over(Window.partitionBy("value").orderBy("key")), + sum("key").over(Window.partitionBy("value").orderBy("key")), + ntile("key").over(Window.partitionBy("value").orderBy("key")), + ntile($"key").over(Window.partitionBy("value").orderBy("key")), + rowNumber().over(Window.partitionBy("value").orderBy("key")), + denseRank().over(Window.partitionBy("value").orderBy("key")), + rank().over(Window.partitionBy("value").orderBy("key")), + cumeDist().over(Window.partitionBy("value").orderBy("key")), + percentRank().over(Window.partitionBy("value").orderBy("key"))), sql( s"""SELECT |key, @@ -184,10 +142,7 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - avg("key").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(1), Frame.following(1)))), + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), sql( """SELECT | avg(key) OVER @@ -200,10 +155,7 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - avg("key").over( - Window.partitionBy($"value") - .orderBy($"key") - .rangeBetween(Frame.preceding(1), Frame.following(1)))), + avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), sql( """SELECT | avg(key) OVER @@ -217,14 +169,8 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - first("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(1), Frame.currentRow)), - first("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(2), Frame.preceding(1)))), + first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 0)), + first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-2, 1))), sql( """SELECT | key, @@ -242,17 +188,10 @@ class HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", last("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.currentRow, Frame.unbounded)), + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), last("value").over( - 
Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.unbounded, Frame.currentRow)), - last("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(1), Frame.following(1)))), + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), sql( """SELECT | key, @@ -270,18 +209,9 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - avg("key").over( - Window.partitionBy($"key") - .orderBy($"value") - .rowsBetween(Frame.preceding(1), Frame.currentRow)), - avg("key").over( - Window.partitionBy($"key") - .orderBy($"value") - .rowsBetween(Frame.currentRow, Frame.currentRow)), - avg("key").over( - Window.partitionBy($"key") - .orderBy($"value") - .rowsBetween(Frame.preceding(2), Frame.preceding(1)))), + avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-1, 0)), + avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(0, 0)), + avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-2, 1))), sql( """SELECT | avg(key) OVER @@ -300,28 +230,14 @@ class HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", last("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rangeBetween(Frame.following(1), Frame.unbounded)) + Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) .equalTo("2") .as("last_v"), - avg("key") - .over( - Window.partitionBy("value") - .orderBy("key") - .rangeBetween(Frame.preceding(2), Frame.following(1))) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-2, 1)) .as("avg_key1"), - avg("key") - .over( - Window.partitionBy("value") - .orderBy("key") - .rangeBetween(Frame.currentRow, Frame.following(1))) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, 1)) .as("avg_key2"), - avg("key") - .over( - Window.partitionBy("value") - .orderBy("key") - .rangeBetween(Frame.preceding(1), Frame.currentRow)) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), sql( From 9794d9dffb1d6a514c5f624d359ecb53e4f3a10c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 21 May 2015 19:11:53 -0700 Subject: [PATCH 14/16] Moved Java test package. 
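Moving the suites from `org.apache.spark.sql.hive` to `test.org.apache.spark.sql.hive` takes them out of Spark's own package, so they can no longer reach package-private members and instead compile purely against the public API, which is what external Java users see.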
--- .../org/apache/spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java | 2 +- .../org/apache/spark/sql/hive/execution/UDFIntegerToString.java | 0 .../org/apache/spark/sql/hive/execution/UDFListListInt.java | 0 .../org/apache/spark/sql/hive/execution/UDFListString.java | 0 .../org/apache/spark/sql/hive/execution/UDFStringString.java | 0 .../org/apache/spark/sql/hive/execution/UDFTwoListList.java | 0 7 files changed, 2 insertions(+), 2 deletions(-) rename sql/hive/src/test/java/{ => test}/org/apache/spark/sql/hive/JavaDataFrameSuite.java (98%) rename sql/hive/src/test/java/{ => test}/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java (99%) rename sql/hive/src/test/java/{ => test}/org/apache/spark/sql/hive/execution/UDFIntegerToString.java (100%) rename sql/hive/src/test/java/{ => test}/org/apache/spark/sql/hive/execution/UDFListListInt.java (100%) rename sql/hive/src/test/java/{ => test}/org/apache/spark/sql/hive/execution/UDFListString.java (100%) rename sql/hive/src/test/java/{ => test}/org/apache/spark/sql/hive/execution/UDFStringString.java (100%) rename sql/hive/src/test/java/{ => test}/org/apache/spark/sql/hive/execution/UDFTwoListList.java (100%) diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java similarity index 98% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index eeb676d3dc12..1ed0facaeeb2 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.hive; +package test.org.apache.spark.sql.hive; import java.io.IOException; import java.util.ArrayList; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java similarity index 99% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index ee21caf63667..4cf67b5c652e 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive; +package test.org.apache.spark.sql.hive; import java.io.File; import java.io.IOException; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java From dc448fe2334684d0354c32110edac59b58d1adf8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 21 May 2015 19:13:11 -0700 Subject: [PATCH 15/16] Fixed Hive tests. 
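After the package move, the suites no longer live in `org.apache.spark.sql.hive`, so `HiveContext` stops resolving by simple name; both suites need the explicit import added below.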
--- .../java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java | 1 + .../org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java | 1 + 2 files changed, 2 insertions(+) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 1ed0facaeeb2..c4828c471764 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; public class JavaDataFrameSuite { diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 4cf67b5c652e..64d1ce92931e 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -37,6 +37,7 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; From 026d587835d5195df06219799b38c506fffbfea3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 21 May 2015 22:51:09 -0700 Subject: [PATCH 16/16] Address code review feedback. --- .../spark/sql/expressions/WindowSpec.scala | 8 +-- .../org/apache/spark/sql/functions.scala | 19 +++++ .../sql/hive/HiveDataFrameWindowSuite.scala | 69 +++++-------------- 3 files changed, 39 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 00ecdb47ca5a..c3d224629702 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -122,11 +122,11 @@ class WindowSpec private[sql]( case x if x > 0 => ValueFollowing(start.toInt) } - val boundaryEnd = start match { + val boundaryEnd = end match { case 0 => CurrentRow - case Long.MinValue => UnboundedFollowing - case x if x < 0 => ValuePreceding(-start.toInt) - case x if x > 0 => ValueFollowing(start.toInt) + case Long.MaxValue => UnboundedFollowing + case x if x < 0 => ValuePreceding(-end.toInt) + case x if x > 0 => ValueFollowing(end.toInt) } new WindowSpec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9188d9719242..8775be724e0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -330,6 +330,7 @@ object functions { * null when the current row extends before the beginning of the window. * * @group window_funcs + * @since 1.4.0 */ def lag(columnName: String): Column = { lag(columnName, 1) @@ -340,6 +341,7 @@ object functions { * null when the current row extends before the beginning of the window. 
* * @group window_funcs + * @since 1.4.0 */ def lag(e: Column): Column = { lag(e, 1) @@ -350,6 +352,7 @@ object functions { * null when the current row extends before the beginning of the window. * * @group window_funcs + * @since 1.4.0 */ def lag(e: Column, count: Int): Column = { lag(e, count, null) @@ -360,6 +363,7 @@ object functions { * null when the current row extends before the beginning of the window. * * @group window_funcs + * @since 1.4.0 */ def lag(columnName: String, count: Int): Column = { lag(columnName, count, null) @@ -371,6 +375,7 @@ object functions { * of the window. * * @group window_funcs + * @since 1.4.0 */ def lag(columnName: String, count: Int, defaultValue: Any): Column = { lag(Column(columnName), count, defaultValue) @@ -382,6 +387,7 @@ object functions { * of the window. * * @group window_funcs + * @since 1.4.0 */ def lag(e: Column, count: Int, defaultValue: Any): Column = { UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil) @@ -392,6 +398,7 @@ object functions { * null when the current row extends before the end of the window. * * @group window_funcs + * @since 1.4.0 */ def lead(columnName: String): Column = { lead(columnName, 1) @@ -402,6 +409,7 @@ object functions { * null when the current row extends before the end of the window. * * @group window_funcs + * @since 1.4.0 */ def lead(e: Column): Column = { lead(e, 1) @@ -412,6 +420,7 @@ object functions { * null when the current row extends before the end of the window. * * @group window_funcs + * @since 1.4.0 */ def lead(columnName: String, count: Int): Column = { lead(columnName, count, null) @@ -422,6 +431,7 @@ object functions { * null when the current row extends before the end of the window. * * @group window_funcs + * @since 1.4.0 */ def lead(e: Column, count: Int): Column = { lead(e, count, null) @@ -432,6 +442,7 @@ object functions { * given default value when the current row extends before the end of the window. * * @group window_funcs + * @since 1.4.0 */ def lead(columnName: String, count: Int, defaultValue: Any): Column = { lead(Column(columnName), count, defaultValue) @@ -442,6 +453,7 @@ object functions { * given default value when the current row extends before the end of the window. * * @group window_funcs + * @since 1.4.0 */ def lead(e: Column, count: Int, defaultValue: Any): Column = { UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil) @@ -454,6 +466,7 @@ object functions { * number of groups called buckets and assigns a bucket number to each row in the partition. * * @group window_funcs + * @since 1.4.0 */ def ntile(e: Column): Column = { UnresolvedWindowFunction("ntile", e.expr :: Nil) @@ -466,6 +479,7 @@ object functions { * number of groups called buckets and assigns a bucket number to each row in the partition. * * @group window_funcs + * @since 1.4.0 */ def ntile(columnName: String): Column = { ntile(Column(columnName)) @@ -476,6 +490,7 @@ object functions { * row within the partition. * * @group window_funcs + * @since 1.4.0 */ def rowNumber(): Column = { UnresolvedWindowFunction("row_number", Nil) @@ -488,6 +503,7 @@ object functions { * place and that the next person came in third. * * @group window_funcs + * @since 1.4.0 */ def denseRank(): Column = { UnresolvedWindowFunction("dense_rank", Nil) @@ -500,6 +516,7 @@ object functions { * place and that the next person came in third. 
* * @group window_funcs + * @since 1.4.0 */ def rank(): Column = { UnresolvedWindowFunction("rank", Nil) @@ -512,6 +529,7 @@ object functions { * CUME_DIST(x) = number of values in S coming before and including x in the specified order / N * * @group window_funcs + * @since 1.4.0 */ def cumeDist(): Column = { UnresolvedWindowFunction("cume_dist", Nil) @@ -524,6 +542,7 @@ object functions { * (rank of row in its partition - 1) / (number of rows in the partition - 1) * * @group window_funcs + * @since 1.4.0 */ def percentRank(): Column = { UnresolvedWindowFunction("percent_rank", Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 6fee3bcb1735..6cea6776c8ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -47,7 +47,7 @@ class HiveDataFrameWindowSuite extends QueryTest { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } - test("lead in window") { + test("lead") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") @@ -60,7 +60,7 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } - test("lag in window") { + test("lag") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") @@ -75,7 +75,7 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } - test("lead in window with default value") { + test("lead with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") @@ -88,7 +88,7 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } - test("lag in window with default value") { + test("lag with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") @@ -134,23 +134,23 @@ class HiveDataFrameWindowSuite extends QueryTest { |rank() over (partition by value order by key), |cume_dist() over (partition by value order by key), |percent_rank() over (partition by value order by key) - |FROM window_table""".stripMargin).collect) + |FROM window_table""".stripMargin).collect()) } - test("aggregation in a row window") { + test("aggregation and rows between") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), sql( """SELECT | avg(key) OVER - | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 1 following) + | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) | FROM window_table""".stripMargin).collect()) } - test("aggregation in a Range window") { + test("aggregation and range betweens") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( @@ -163,25 +163,7 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } - test("Aggregate function in Row preceding 
Window") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "3"), (4, "3")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - $"key", - first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 0)), - first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-2, 1))), - sql( - """SELECT - | key, - | first_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS 1 preceding), - | first_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between 2 preceding and 1 preceding) - | FROM window_table""".stripMargin).collect()) - } - - test("Aggregate function in Row following Window") { + test("aggregation and rows betweens with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( @@ -191,7 +173,7 @@ class HiveDataFrameWindowSuite extends QueryTest { Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), last("value").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), - last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), + last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), sql( """SELECT | key, @@ -204,26 +186,7 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } - test("Multiple aggregate functions in row window") { - val df = Seq((1, "1"), (1, "2"), (3, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-1, 0)), - avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(0, 0)), - avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-2, 1))), - sql( - """SELECT - | avg(key) OVER - | (partition by key ORDER BY value rows 1 preceding), - | avg(key) OVER - | (partition by key ORDER BY value rows between current row and current row), - | avg(key) OVER - | (partition by key ORDER BY value rows between 2 preceding and 1 preceding) - | FROM window_table""".stripMargin).collect()) - } - - test("Multiple aggregate functions in range window") { + test("aggregation and range betweens with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( @@ -233,9 +196,9 @@ class HiveDataFrameWindowSuite extends QueryTest { Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) .equalTo("2") .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-2, 1)) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, 1)) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) .as("avg_key2"), avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") @@ -246,9 +209,9 @@ class HiveDataFrameWindowSuite extends QueryTest { | last_value(value) OVER | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 following), + | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), | avg(key) OVER - | 
(PARTITION BY value ORDER BY key RANGE BETWEEN current row and 1 following), + | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), | avg(key) OVER | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) | FROM window_table""".stripMargin).collect())
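For reference, the `Long` boundary encoding that `rowsBetween`/`rangeBetween` settle on (after the `boundaryEnd` fix above) maps to SQL frame clauses as in this illustrative Scala sketch; `describeBoundary` is a made-up helper for exposition, not shipped code:

    def describeBoundary(offset: Long): String = offset match {
      case 0L            => "CURRENT ROW"
      case Long.MinValue => "UNBOUNDED PRECEDING"  // only meaningful as a frame start
      case Long.MaxValue => "UNBOUNDED FOLLOWING"  // only meaningful as a frame end
      case n if n < 0    => s"${-n} PRECEDING"     // e.g. -1 -> "1 PRECEDING"
      case n             => s"$n FOLLOWING"        // e.g.  3 -> "3 FOLLOWING"
    }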