diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index dc0aeea7c4ae..d9a715fd645f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -889,6 +889,21 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)
 
+  /**
+   * Defines a windowing column over the given [[Window]] specification.
+   * {{{
+   *   val w = Window.partitionBy("name").orderBy("id")
+   *   df.select(
+   *     sum("price").over(w.rangeBetween(Frame.preceding(2), Frame.currentRow)),
+   *     avg("price").over(w.rangeBetween(Frame.preceding(4), Frame.currentRow)),
+   *     avg("price").over(Window.partitionBy("name").orderBy("id")
+   *       .rangeBetween(Frame.preceding(1), Frame.currentRow))
+   *   )
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def over(w: Window): Column = w.newColumn(this).toColumn
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/Window.scala
new file mode 100644
index 000000000000..80272f380b11
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Window.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.expressions._
+
+
+sealed private[sql] class Frame(private[sql] var boundary: FrameBoundary = null)
+
+/**
+ * :: Experimental ::
+ * A utility for specifying the boundaries of a window frame.
+ */
+@Experimental
+object Frame {
+  /** A boundary anchored at the current row. */
+  val currentRow: Frame = new Frame(CurrentRow)
+
+  /** An unbounded boundary: UNBOUNDED PRECEDING as a start, UNBOUNDED FOLLOWING as an end. */
+  val unbounded: Frame = new Frame()
+
+  /** A boundary `n` rows (or values, for range frames) before the current row. */
+  def preceding(n: Int): Frame = if (n == 0) {
+    new Frame(CurrentRow)
+  } else {
+    new Frame(ValuePreceding(n))
+  }
+
+  /** A boundary `n` rows (or values, for range frames) after the current row. */
+  def following(n: Int): Frame = if (n == 0) {
+    new Frame(CurrentRow)
+  } else {
+    new Frame(ValueFollowing(n))
+  }
+}
+
+/**
+ * :: Experimental ::
+ * A [[Window]] with nothing set. It serves as the entry point for building
+ * new window specifications.
+ */
+@Experimental
+object Window extends Window()
+
+/**
+ * :: Experimental ::
+ * A builder of window specifications for use with aggregate and window function expressions.
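+ * Each builder method returns a new [[Window]] instance, so a partially built
+ * specification can be stored in a val and reused without being mutated.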
+ * For example:
+ * {{{
+ *   // predefine a reusable window
+ *   val w = Window.partitionBy("name").orderBy("id")
+ *     .rowsBetween(Frame.unbounded, Frame.currentRow)
+ *   df.select(
+ *     avg("age").over(w)
+ *   )
+ *
+ *   df.select(
+ *     avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..")
+ *       .rowsBetween(Frame.preceding(50), Frame.following(10)))
+ *   )
+ * }}}
+ */
+@Experimental
+class Window {
+  private var column: Column = _
+  private var partitionSpec: Seq[Expression] = Nil
+  private var orderSpec: Seq[SortOrder] = Nil
+  private var frame: WindowFrame = UnspecifiedFrame
+
+  private def this(
+      column: Column = null,
+      partitionSpec: Seq[Expression] = Nil,
+      orderSpec: Seq[SortOrder] = Nil,
+      frame: WindowFrame = UnspecifiedFrame) {
+    this()
+    this.column = column
+    this.partitionSpec = partitionSpec
+    this.orderSpec = orderSpec
+    this.frame = frame
+  }
+
+  private[sql] def newColumn(c: Column): Window = {
+    new Window(c, partitionSpec, orderSpec, frame)
+  }
+
+  /**
+   * Returns a new [[Window]] partitioned by the specified columns.
+   * {{{
+   *   // The following 2 are equivalent
+   *   df.select(avg("age").over(Window.partitionBy("k1", "k2")))
+   *   df.select(avg("age").over(Window.partitionBy($"k1", $"k2")))
+   * }}}
+   * @group window_funcs
+   */
+  @scala.annotation.varargs
+  def partitionBy(colName: String, colNames: String*): Window = {
+    partitionBy((colName +: colNames).map(Column(_)): _*)
+  }
+
+  /**
+   * Returns a new [[Window]] partitioned by the specified columns. For example:
+   * {{{
+   *   df.select(avg("age").over(Window.partitionBy($"col1", $"col2")))
+   * }}}
+   * @group window_funcs
+   */
+  @scala.annotation.varargs
+  def partitionBy(cols: Column*): Window = {
+    new Window(column, cols.map(_.expr), orderSpec, frame)
+  }
+
+  /**
+   * Returns a new [[Window]] sorted by the specified columns within
+   * each partition.
+   * {{{
+   *   // The following 2 are equivalent
+   *   df.select(avg("age").over(Window.partitionBy("k1").orderBy("k2", "k3")))
+   *   df.select(avg("age").over(Window.partitionBy("k1").orderBy($"k2", $"k3")))
+   * }}}
+   * @group window_funcs
+   */
+  @scala.annotation.varargs
+  def orderBy(colName: String, colNames: String*): Window = {
+    orderBy((colName +: colNames).map(Column(_)): _*)
+  }
+
+  /**
+   * Returns a new [[Window]] sorted by the specified columns within
+   * each partition.
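+   * Plain column references sort in ascending order; use an expression such as
+   * `$"k2".desc` for descending order.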
+   * For example:
+   * {{{
+   *   df.select(avg("age").over(Window.partitionBy("k1").orderBy($"k2", $"k3")))
+   * }}}
+   * @group window_funcs
+   */
+  @scala.annotation.varargs
+  def orderBy(cols: Column*): Window = {
+    val sortOrder: Seq[SortOrder] = cols.map { col =>
+      col.expr match {
+        case expr: SortOrder =>
+          expr
+        case expr: Expression =>
+          SortOrder(expr, Ascending)
+      }
+    }
+    new Window(column, partitionSpec, sortOrder, frame)
+  }
+
+  /**
+   * Returns a new [[Window]] with a row-based frame: ROWS BETWEEN `start` AND `end`.
+   * `Frame.unbounded` is interpreted as UNBOUNDED PRECEDING for `start` and as
+   * UNBOUNDED FOLLOWING for `end`.
+   * @group window_funcs
+   */
+  def rowsBetween(start: Frame, end: Frame): Window = {
+    assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing")
+    assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding")
+
+    val s = if (start.boundary == null) UnboundedPreceding else start.boundary
+    val e = if (end.boundary == null) UnboundedFollowing else end.boundary
+
+    new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RowFrame, s, e))
+  }
+
+  /**
+   * Returns a new [[Window]] with a value-based frame: RANGE BETWEEN `start` AND `end`.
+   * `Frame.unbounded` is interpreted as UNBOUNDED PRECEDING for `start` and as
+   * UNBOUNDED FOLLOWING for `end`.
+   * @group window_funcs
+   */
+  def rangeBetween(start: Frame, end: Frame): Window = {
+    assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing")
+    assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding")
+
+    val s = if (start.boundary == null) UnboundedPreceding else start.boundary
+    val e = if (end.boundary == null) UnboundedFollowing else end.boundary
+
+    new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RangeFrame, s, e))
+  }
+
+  /**
+   * Converts this window definition into a Column object.
+   * @group window_funcs
+   */
+  private[sql] def toColumn: Column = {
+    if (column == null) {
+      throw new AnalysisException(
+        "Window is not bound to an expression; use Column.over(window) to bind it")
+    }
+    val windowExpr = column.expr match {
+      case Average(child) => WindowExpression(
+        UnresolvedWindowFunction("avg", child :: Nil),
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case Sum(child) => WindowExpression(
+        UnresolvedWindowFunction("sum", child :: Nil),
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case Count(child) => WindowExpression(
+        UnresolvedWindowFunction("count", child :: Nil),
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case First(child) => WindowExpression(
+        // TODO this is a hack for Hive UDAF first_value
+        UnresolvedWindowFunction("first_value", child :: Nil),
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case Last(child) => WindowExpression(
+        // TODO this is a hack for Hive UDAF last_value
+        UnresolvedWindowFunction("last_value", child :: Nil),
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case Min(child) => WindowExpression(
+        UnresolvedWindowFunction("min", child :: Nil),
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case Max(child) => WindowExpression(
+        UnresolvedWindowFunction("max", child :: Nil),
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case wf: WindowFunction => WindowExpression(
+        wf,
+        WindowSpecDefinition(partitionSpec, orderSpec, frame))
+      case x =>
+        throw new UnsupportedOperationException(s"$x is not supported in a window operation.")
+    }
+    new Column(windowExpr)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6640631cf071..9188d9719242 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -37,6 +37,7 @@ import org.apache.spark.util.Utils
  * @groupname sort_funcs Sorting functions
  * @groupname normal_funcs Non-aggregate functions
  * @groupname math_funcs Math functions
+ * @groupname window_funcs Window functions
  * @groupname Ungrouped Support functions for DataFrames.
  * @since 1.3.0
  */
@@ -320,6 +321,214 @@ object functions {
    */
   def max(columnName: String): Column = max(Column(columnName))
 
+  //////////////////////////////////////////////////////////////////////////////////////////////
+  // Window functions
+  //////////////////////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Window function: returns the value of the column one row before the current row,
+   * and null if there is no previous row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lag(columnName: String): Column = {
+    lag(columnName, 1)
+  }
+
+  /**
+   * Window function: returns the value of the expression one row before the current row,
+   * and null if there is no previous row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lag(e: Column): Column = {
+    lag(e, 1)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows before the current row,
+   * and null if there are fewer than `count` rows before the current row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lag(e: Column, count: Int): Column = {
+    lag(e, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows before the current row,
+   * and null if there are fewer than `count` rows before the current row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lag(columnName: String, count: Int): Column = {
+    lag(columnName, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows before the current row,
+   * and `defaultValue` if there are fewer than `count` rows before the current row
+   * in the partition.
+   *
+   * @group window_funcs
+   */
+  def lag(columnName: String, count: Int, defaultValue: Any): Column = {
+    lag(Column(columnName), count, defaultValue)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows before the current row,
+   * and `defaultValue` if there are fewer than `count` rows before the current row
+   * in the partition.
+   *
+   * @group window_funcs
+   */
+  def lag(e: Column, count: Int, defaultValue: Any): Column = {
+    UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
+  }
+
+  /**
+   * Window function: returns the value of the column one row after the current row,
+   * and null if there is no subsequent row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lead(columnName: String): Column = {
+    lead(columnName, 1)
+  }
+
+  /**
+   * Window function: returns the value of the expression one row after the current row,
+   * and null if there is no subsequent row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lead(e: Column): Column = {
+    lead(e, 1)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows after the current row,
+   * and null if there are fewer than `count` rows after the current row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lead(columnName: String, count: Int): Column = {
+    lead(columnName, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows after the current row,
+   * and null if there are fewer than `count` rows after the current row in the partition.
+   *
+   * @group window_funcs
+   */
+  def lead(e: Column, count: Int): Column = {
+    lead(e, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows after the current row,
+   * and `defaultValue` if there are fewer than `count` rows after the current row
+   * in the partition.
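+   * For example (assuming `df` has columns `name`, `id` and `price`):
+   * {{{
+   *   df.select(lead("price", 2, 0).over(Window.partitionBy("name").orderBy("id")))
+   * }}}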
+   *
+   * @group window_funcs
+   */
+  def lead(columnName: String, count: Int, defaultValue: Any): Column = {
+    lead(Column(columnName), count, defaultValue)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows after the current row,
+   * and `defaultValue` if there are fewer than `count` rows after the current row
+   * in the partition.
+   *
+   * @group window_funcs
+   */
+  def lead(e: Column, count: Int, defaultValue: Any): Column = {
+    UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
+  }
+
+  /**
+   * Window function: NTILE for the specified expression.
+   * NTILE allows easy calculation of tertiles, quartiles, deciles and other
+   * common summary statistics. This function divides an ordered partition into a specified
+   * number of groups called buckets and assigns a bucket number to each row in the partition.
+   *
+   * @group window_funcs
+   */
+  def ntile(e: Column): Column = {
+    UnresolvedWindowFunction("ntile", e.expr :: Nil)
+  }
+
+  /**
+   * Window function: NTILE for the specified column.
+   * NTILE allows easy calculation of tertiles, quartiles, deciles and other
+   * common summary statistics. This function divides an ordered partition into a specified
+   * number of groups called buckets and assigns a bucket number to each row in the partition.
+   *
+   * @group window_funcs
+   */
+  def ntile(columnName: String): Column = {
+    ntile(Column(columnName))
+  }
+
+  /**
+   * Window function: assigns a unique number (sequentially, starting from 1, as defined by
+   * ORDER BY) to each row within the partition.
+   *
+   * @group window_funcs
+   */
+  def rowNumber(): Column = {
+    UnresolvedWindowFunction("row_number", Nil)
+  }
+
+  /**
+   * Window function: DENSE_RANK, unlike RANK, leaves no gaps in the ranking sequence when
+   * there are ties. That is, if three people tie for second place, all three are in second
+   * place and the next person comes in third.
+   *
+   * @group window_funcs
+   */
+  def denseRank(): Column = {
+    UnresolvedWindowFunction("dense_rank", Nil)
+  }
+
+  /**
+   * Window function: RANK, unlike DENSE_RANK, leaves gaps in the ranking sequence when
+   * there are ties. That is, if three people tie for second place, all three are in second
+   * place and the next person comes in fifth.
+   *
+   * @group window_funcs
+   */
+  def rank(): Column = {
+    UnresolvedWindowFunction("rank", Nil)
+  }
+
+  /**
+   * Window function: CUME_DIST (defined as the inverse of percentile in some statistical
+   * books) computes the position of a specified value relative to a set of values.
+   * To compute the CUME_DIST of a value x in a set S of size N, you use the formula:
+   * CUME_DIST(x) = number of values in S coming before and including x in the specified order / N
+   *
+   * @group window_funcs
+   */
+  def cumeDist(): Column = {
+    UnresolvedWindowFunction("cume_dist", Nil)
+  }
+
+  /**
+   * Window function: PERCENT_RANK is similar to CUME_DIST, but it uses rank values rather
+   * than row counts in its numerator.
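+   * For example, ranks (1, 2, 2, 4, 5) over a five-row partition yield percent_rank
+   * values (0.0, 0.25, 0.25, 0.75, 1.0).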
+   * The formula:
+   * (rank of row in its partition - 1) / (number of rows in the partition - 1)
+   *
+   * @group window_funcs
+   */
+  def percentRank(): Column = {
+    UnresolvedWindowFunction("percent_rank", Nil)
+  }
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Non-aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
new file mode 100644
index 000000000000..9d0feb4e5bd3
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.hive.test.TestHive$;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+public class JavaDataFrameSuite {
+  private transient JavaSparkContext sc;
+  private transient HiveContext hc;
+
+  DataFrame df;
+
+  private void checkAnswer(DataFrame actual, List<Row> expected) {
+    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+    if (errorMessage != null) {
+      Assert.fail(errorMessage);
+    }
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    hc = TestHive$.MODULE$;
+    sc = new JavaSparkContext(hc.sparkContext());
+
+    List<String> jsonObjects = new ArrayList<String>(10);
+    for (int i = 0; i < 10; i++) {
+      jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}");
+    }
+    df = hc.jsonRDD(sc.parallelize(jsonObjects));
+    df.registerTempTable("window_table");
+  }
+
+  @After
+  public void tearDown() throws IOException {
+    // Clean up the temporary table.
+ hc.sql("DROP TABLE IF EXISTS window_table"); + } + + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select( + functions.avg("key").over( + Window$.MODULE$.partitionBy("value") + .orderBy("key") + .rowsBetween(Frame.preceding(1), Frame.following(1)))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala new file mode 100644 index 000000000000..29661cb8a508 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Frame, Window, Row, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +class HiveDataFrameWindowSuite extends QueryTest { + + test("reuse window partitionBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key").over(w), + lead("value").over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.orderBy("value").partitionBy("key") + + checkAnswer( + df.select( + lead("key").over(w), + lead("value").over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("lead in window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lead("value").over( + Window.partitionBy($"key") + .orderBy($"value"))), + sql( + """SELECT + | lead(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag in window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lag("value").over( + Window.partitionBy($"key") + .orderBy($"value"))), + sql( + """SELECT + | lag(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("last in window with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + 
last("value").over(Window)), + sql( + """SELECT + | last_value(value) OVER () + | FROM window_table""".stripMargin).collect()) + } + + test("lead in window with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lead("value", 2, "n/a").over( + Window.partitionBy("key") + .orderBy("value"))), + sql( + """SELECT + | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag in window with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lag("value", 2, "n/a").over( + Window.partitionBy($"key") + .orderBy($"value"))), + sql( + """SELECT + | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("rank functions in unspecific window") { + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + max("key").over( + Window.partitionBy("value") + .orderBy("key")), + min("key").over( + Window.partitionBy("value") + .orderBy("key")), + mean("key").over( + Window.partitionBy("value") + .orderBy("key")), + count("key").over( + Window.partitionBy("value") + .orderBy("key")), + sum("key").over( + Window.partitionBy("value") + .orderBy("key")), + ntile("key").over( + Window.partitionBy("value") + .orderBy("key")), + ntile($"key").over( + Window.partitionBy("value") + .orderBy("key")), + rowNumber().over( + Window.partitionBy("value") + .orderBy("key")), + denseRank().over( + Window.partitionBy("value") + .orderBy("key")), + rank().over( + Window.partitionBy("value") + .orderBy("key")), + cumeDist().over( + Window.partitionBy("value") + .orderBy("key")), + percentRank().over( + Window.partitionBy("value") + .orderBy("key"))), + sql( + s"""SELECT + |key, + |max(key) over (partition by value order by key), + |min(key) over (partition by value order by key), + |avg(key) over (partition by value order by key), + |count(key) over (partition by value order by key), + |sum(key) over (partition by value order by key), + |ntile(key) over (partition by value order by key), + |ntile(key) over (partition by value order by key), + |row_number() over (partition by value order by key), + |dense_rank() over (partition by value order by key), + |rank() over (partition by value order by key), + |cume_dist() over (partition by value order by key), + |percent_rank() over (partition by value order by key) + |FROM window_table""".stripMargin).collect) + } + + test("aggregation in a row window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over( + Window.partitionBy($"value") + .orderBy($"key") + .rowsBetween(Frame.preceding(1), Frame.following(1)))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 1 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation in a Range window") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over( + Window.partitionBy($"value") + .orderBy($"key") + .rangeBetween(Frame.preceding(1), 
+              Frame.following(1)))),
+      sql(
+        """SELECT
+          | avg(key) OVER
+          | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following)
+          | FROM window_table""".stripMargin).collect())
+  }
+
+  test("Aggregate function in Row preceding Window") {
+    val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "3"), (4, "3")).toDF("key", "value")
+    df.registerTempTable("window_table")
+    checkAnswer(
+      df.select(
+        $"key",
+        first("value").over(
+          Window.partitionBy($"value")
+            .orderBy($"key")
+            .rowsBetween(Frame.preceding(1), Frame.currentRow)),
+        first("value").over(
+          Window.partitionBy($"value")
+            .orderBy($"key")
+            .rowsBetween(Frame.preceding(2), Frame.preceding(1)))),
+      sql(
+        """SELECT
+          | key,
+          | first_value(value) OVER
+          | (PARTITION BY value ORDER BY key ROWS 1 preceding),
+          | first_value(value) OVER
+          | (PARTITION BY value ORDER BY key ROWS between 2 preceding and 1 preceding)
+          | FROM window_table""".stripMargin).collect())
+  }
+
+  test("Aggregate function in Row following Window") {
+    val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
+    df.registerTempTable("window_table")
+    checkAnswer(
+      df.select(
+        $"key",
+        last("value").over(
+          Window.partitionBy($"value")
+            .orderBy($"key")
+            .rowsBetween(Frame.currentRow, Frame.unbounded)),
+        last("value").over(
+          Window.partitionBy($"value")
+            .orderBy($"key")
+            .rowsBetween(Frame.unbounded, Frame.currentRow)),
+        last("value").over(
+          Window.partitionBy($"value")
+            .orderBy($"key")
+            .rowsBetween(Frame.preceding(1), Frame.following(1)))),
+      sql(
+        """SELECT
+          | key,
+          | last_value(value) OVER
+          | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following),
+          | last_value(value) OVER
+          | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row),
+          | last_value(value) OVER
+          | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 1 following)
+          | FROM window_table""".stripMargin).collect())
+  }
+
+  test("Multiple aggregate functions in row window") {
+    val df = Seq((1, "1"), (1, "2"), (3, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+    df.registerTempTable("window_table")
+    checkAnswer(
+      df.select(
+        avg("key").over(
+          Window.partitionBy($"key")
+            .orderBy($"value")
+            .rowsBetween(Frame.preceding(1), Frame.currentRow)),
+        avg("key").over(
+          Window.partitionBy($"key")
+            .orderBy($"value")
+            .rowsBetween(Frame.currentRow, Frame.currentRow)),
+        avg("key").over(
+          Window.partitionBy($"key")
+            .orderBy($"value")
+            .rowsBetween(Frame.preceding(2), Frame.preceding(1)))),
+      sql(
+        """SELECT
+          | avg(key) OVER
+          | (partition by key ORDER BY value rows 1 preceding),
+          | avg(key) OVER
+          | (partition by key ORDER BY value rows between current row and current row),
+          | avg(key) OVER
+          | (partition by key ORDER BY value rows between 2 preceding and 1 preceding)
+          | FROM window_table""".stripMargin).collect())
+  }
+
+  test("Multiple aggregate functions in range window") {
+    val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+    df.registerTempTable("window_table")
+    checkAnswer(
+      df.select(
+        $"key",
+        last("value").over(
+          Window.partitionBy($"value")
+            .orderBy($"key")
+            .rangeBetween(Frame.preceding(1), Frame.currentRow))
+          .equalTo("2")
+          .as("last_v"),
+        avg("key")
+          .over(
+            Window.partitionBy("value")
+              .orderBy("key")
+              .rangeBetween(Frame.preceding(2), Frame.following(1)))
+          .as("avg_key1"),
+        avg("key")
+          .over(
+            Window.partitionBy("value")
+              .orderBy("key")
+              .rangeBetween(Frame.currentRow,
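+                // values from the current key through key + 1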
Frame.following(1))) + .as("avg_key2"), + avg("key") + .over( + Window.partitionBy("value") + .orderBy("key") + .rangeBetween(Frame.preceding(1), Frame.currentRow)) + .as("avg_key3") + ), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) + | FROM window_table""".stripMargin).collect()) + } +}
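
Taken together, the pieces above compose as in the following sketch. It is illustrative only: it assumes a HiveContext with its implicits in scope and a DataFrame `df` with columns `name`, `id` and `price`, none of which are defined by this patch.

```scala
import org.apache.spark.sql.{Frame, Window}
import org.apache.spark.sql.functions._

// A reusable window specification: one window per name, ordered by id.
val byName = Window.partitionBy("name").orderBy("id")

df.select(
  $"name",
  $"price",
  // Running total from the start of the partition up to the current row.
  sum("price").over(byName.rowsBetween(Frame.unbounded, Frame.currentRow)),
  // Moving average over the two previous rows through the next row.
  avg("price").over(byName.rowsBetween(Frame.preceding(2), Frame.following(1))),
  // Ranking functions take no frame and use the window's ordering.
  rank().over(byName))
```

Because each builder method returns a new Window, `byName` is unchanged by the `rowsBetween` calls and can be reused across all three expressions.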