Skip to content

Commit 3313e2a

Browse files
committed
Merge pull request #6104 from chenghao-intel/df_window
[SPARK-7322] [SQL] [WIP] Support Window Function in DataFrame
2 parents 15680ae + d625a64 commit 3313e2a

File tree

5 files changed

+873
-0
lines changed

5 files changed

+873
-0
lines changed

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,21 @@ class Column(protected[sql] val expr: Expression) extends Logging {
889889
*/
890890
def bitwiseXOR(other: Any): Column = {
  // Wrap the right-hand operand as a literal expression before combining.
  val operand = lit(other).expr
  BitwiseXor(expr, operand)
}
891891

892+
/**
893+
* Define a [[Window]] column.
894+
* {{{
895+
* val w = Window.partitionBy("name").orderBy("id")
896+
* df.select(
897+
* sum("price").over(w.range.preceding(2)),
898+
* avg("price").over(w.range.preceding(4)),
899+
* avg("price").over(Window.partitionBy("name").orderBy("id").range.preceding(1))
900+
* )
901+
* }}}
902+
*
903+
* @group expr_ops
904+
*/
905+
def over(w: Window): Column = {
  // Bind this column to a copy of the window definition, then turn that
  // bound definition into the resulting window column.
  val boundWindow = w.newColumn(this)
  boundWindow.toColumn
}
906+
892907
}
893908

894909

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import scala.language.implicitConversions
21+
22+
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.sql.catalyst.expressions._
24+
25+
26+
/**
 * Holder for one boundary (start or end) of a window frame.
 *
 * A `null` boundary stands for "unbounded": `Frame.unbounded` is created without a boundary,
 * and [[Window]]'s `rowsBetween` / `rangeBetween` replace a `null` start with
 * `UnboundedPreceding` and a `null` end with `UnboundedFollowing`.
 */
sealed private[sql] class Frame(private[sql] var boundary: FrameBoundary = null)
27+
28+
/**
29+
* :: Experimental ::
30+
* A utility to specify the Window Frame Range.
31+
*/
32+
@Experimental
object Frame {
  /** Frame boundary located at the current row. */
  val currentRow: Frame = new Frame(CurrentRow)

  /**
   * Unbounded frame boundary. It carries no explicit boundary; whether it means
   * "unbounded preceding" or "unbounded following" is decided by its position
   * (start or end) in `Window.rowsBetween` / `Window.rangeBetween`.
   */
  val unbounded: Frame = new Frame()

  /**
   * Frame boundary `n` before the current row.
   *
   * @param n offset from the current row; `0` yields the current row itself.
   */
  def preceding(n: Int): Frame =
    new Frame(if (n == 0) CurrentRow else ValuePreceding(n))

  /**
   * Frame boundary `n` after the current row.
   *
   * @param n offset from the current row; `0` yields the current row itself.
   */
  def following(n: Int): Frame =
    new Frame(if (n == 0) CurrentRow else ValueFollowing(n))
}
47+
48+
/**
49+
* :: Experimental ::
50+
* A [[Window]] object with everything unset (no partitioning, ordering or frame).
51+
* New Window definitions are built from it.
52+
*/
53+
@Experimental
54+
// Extends the class itself, so `Window.partitionBy(...)` / `Window.orderBy(...)` start from a
// definition with no partitioning, no ordering and an unspecified frame.
object Window extends Window()
55+
56+
/**
57+
* :: Experimental ::
58+
* A set of methods for window function definition for aggregate expressions.
59+
* For example:
60+
* {{{
61+
* // predefine a window
62+
* val w = Window.partitionBy("name").orderBy("id")
63+
* .rowsBetween(Frame.unbounded, Frame.currentRow)
64+
* df.select(
65+
* avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..")
66+
* .rowsBetween(Frame.unbounded, Frame.currentRow))
67+
* )
68+
*
69+
* df.select(
70+
* avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..")
71+
* .rowsBetween(Frame.preceding(50), Frame.following(10)))
72+
* )
73+
*
74+
* }}}
75+
*
76+
*/
77+
@Experimental
78+
class Window private (
    column: Column,
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    frame: WindowFrame) {

  /**
   * Creates a window definition with nothing set: no bound column, no partitioning,
   * no ordering and an unspecified frame.
   */
  def this() = this(null, Nil, Nil, UnspecifiedFrame)

  /**
   * Returns a copy of this window definition bound to the column `c`.
   * A definition must be bound to a column before [[toColumn]] can be called.
   */
  private[sql] def newColumn(c: Column): Window = {
    new Window(c, partitionSpec, orderSpec, frame)
  }

  /**
   * Returns a new [[Window]] partitioned by the specified columns.
   * {{{
   *   // The following 2 are equivalent
   *   avg("age").over(Window.partitionBy("k1", "k2"))
   *   avg("age").over(Window.partitionBy($"k1", $"k2"))
   * }}}
   * @group window_funcs
   */
  @scala.annotation.varargs
  def partitionBy(colName: String, colNames: String*): Window = {
    partitionBy((colName +: colNames).map(Column(_)): _*)
  }

  /**
   * Returns a new [[Window]] partitioned by the specified columns. For example:
   * {{{
   *   avg("age").over(Window.partitionBy($"col1", $"col2"))
   * }}}
   * @group window_funcs
   */
  @scala.annotation.varargs
  def partitionBy(cols: Column*): Window = {
    new Window(column, cols.map(_.expr), orderSpec, frame)
  }

  /**
   * Returns a new [[Window]] sorted by the specified columns within the partition.
   * {{{
   *   // The following 2 are equivalent
   *   avg("age").over(Window.partitionBy("k1").orderBy("k2", "k3"))
   *   avg("age").over(Window.partitionBy("k1").orderBy($"k2", $"k3"))
   * }}}
   * @group window_funcs
   */
  @scala.annotation.varargs
  def orderBy(colName: String, colNames: String*): Window = {
    orderBy((colName +: colNames).map(Column(_)): _*)
  }

  /**
   * Returns a new [[Window]] sorted by the specified columns within the partition.
   * A column whose expression is already a `SortOrder` keeps it; any other
   * expression is sorted ascending. For example:
   * {{{
   *   avg("age").over(Window.partitionBy("k1").orderBy($"k2", $"k3"))
   * }}}
   * @group window_funcs
   */
  @scala.annotation.varargs
  def orderBy(cols: Column*): Window = {
    val sortOrder: Seq[SortOrder] = cols.map(_.expr).map {
      case ordered: SortOrder => ordered
      case other => SortOrder(other, Ascending)
    }
    new Window(column, partitionSpec, sortOrder, frame)
  }

  /**
   * Returns a new [[Window]] whose frame is a row frame running from `start` to `end`.
   * `Frame.unbounded` means unbounded preceding when used as `start` and unbounded
   * following when used as `end`.
   * {{{
   *   Window.partitionBy("name").orderBy("id").rowsBetween(Frame.unbounded, Frame.currentRow)
   * }}}
   * @group window_funcs
   */
  def rowsBetween(start: Frame, end: Frame): Window = {
    val (lower, upper) = resolveBoundaries(start, end)
    new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RowFrame, lower, upper))
  }

  /**
   * Returns a new [[Window]] whose frame is a range frame running from `start` to `end`.
   * `Frame.unbounded` means unbounded preceding when used as `start` and unbounded
   * following when used as `end`.
   * {{{
   *   Window.partitionBy("name").orderBy("id").rangeBetween(Frame.preceding(50), Frame.following(10))
   * }}}
   * @group window_funcs
   */
  def rangeBetween(start: Frame, end: Frame): Window = {
    val (lower, upper) = resolveBoundaries(start, end)
    new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RangeFrame, lower, upper))
  }

  /**
   * Validates the two frame ends and substitutes the unbounded defaults: a `null`
   * start boundary becomes `UnboundedPreceding`, a `null` end boundary `UnboundedFollowing`.
   */
  private def resolveBoundaries(start: Frame, end: Frame): (FrameBoundary, FrameBoundary) = {
    assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing")
    assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding")

    val lower = if (start.boundary == null) UnboundedPreceding else start.boundary
    val upper = if (end.boundary == null) UnboundedFollowing else end.boundary
    (lower, upper)
  }

  /**
   * Convert the window definition into a Column object.
   *
   * Throws an `AnalysisException` when no column has been bound to this definition, and an
   * `UnsupportedOperationException` when the bound expression is neither one of the
   * aggregates handled below nor a `WindowFunction`.
   * @group window_funcs
   */
  private[sql] def toColumn: Column = {
    if (column == null) {
      throw new AnalysisException("Window didn't bind with expression")
    }

    // Built on demand so nothing is constructed for unsupported expressions.
    def spec = WindowSpecDefinition(partitionSpec, orderSpec, frame)

    // Aggregates are re-expressed as window functions that are looked up by name later.
    def unresolved(name: String, child: Expression) =
      WindowExpression(UnresolvedWindowFunction(name, child :: Nil), spec)

    val windowExpr = column.expr match {
      case Average(child) => unresolved("avg", child)
      case Sum(child) => unresolved("sum", child)
      case Count(child) => unresolved("count", child)
      // TODO this is a hack for Hive UDAF first_value
      case First(child) => unresolved("first_value", child)
      // TODO this is a hack for Hive UDAF last_value
      case Last(child) => unresolved("last_value", child)
      case Min(child) => unresolved("min", child)
      case Max(child) => unresolved("max", child)
      case wf: WindowFunction => WindowExpression(wf, spec)
      case x =>
        throw new UnsupportedOperationException(s"We don't support $x in window operation.")
    }
    new Column(windowExpr)
  }
}

0 commit comments

Comments
 (0)