Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class Analyzer(
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
TimeWindowing ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
Expand Down Expand Up @@ -1585,3 +1586,92 @@ object ResolveUpCast extends Rule[LogicalPlan] {
}
}
}

/**
 * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
 * figure out how many windows a time column can map to, we over-estimate the number of windows and
 * filter out the rows where the time column is not inside the time window.
 */
object TimeWindowing extends Rule[LogicalPlan] {
  import org.apache.spark.sql.catalyst.dsl.expressions._

  private final val WINDOW_START = "start"
  private final val WINDOW_END = "end"

  /**
   * Generates the logical plan for generating window ranges on a timestamp column. Without
   * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many
   * window ranges a timestamp will map to given all possible combinations of a window duration,
   * slide duration and start time (offset). Therefore, we over-estimate the number of windows
   * there may be, and filter the valid windows. The candidate windows are emitted as a struct so
   * they can be accessed as `window.start` and `window.end`.
   *
   * The candidate windows are calculated as below:
   * maxNumOverlapping <- ceil(windowDuration / slideDuration)
   * for (i <- 0 to maxNumOverlapping)
   *   windowId <- ceil((timestamp - startTime) / slideDuration)
   *   windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime
   *   windowEnd <- windowStart + windowDuration
   *   return windowStart, windowEnd
   *
   * This behaves as follows for the given parameters for the time: 12:05. The valid windows are
   * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the
   * Filter operator.
   * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m
   *     11:50 - 12:02  x                      11:52 - 12:04  x
   *     11:55 - 12:07  +                      11:57 - 12:09  +
   *     12:00 - 12:12  +                      12:02 - 12:14  +
   *     12:05 - 12:17  +                      12:07 - 12:19  x
   *
   * Only operators with exactly one child are considered. An operator containing more than one
   * distinct time window expression fails analysis; an operator whose single window is not yet
   * resolvable or fails its input checks is returned unchanged.
   *
   * @param plan The logical plan
   * @return the logical plan that will generate the time windows using the Expand operator, with
   *         the Filter operator for correctness.
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case p: LogicalPlan if p.children.size == 1 =>
      val child = p.children.head
      // NOTE(review): the original author marked this collection step as "Not correct" without
      // elaborating; confirm which cases it mishandles before relying on it.
      val windowExpressions =
        p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList

      windowExpressions match {
        // Only support a single window expression for now
        case window :: Nil
            if window.timeColumn.resolved && window.checkInputDataTypes().isSuccess =>
          val windowAttr = AttributeReference("window", window.dataType)()

          val maxNumOverlapping =
            math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
          // The window id does not depend on the candidate index, so build it once.
          val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) /
            window.slideDuration)
          // One candidate more than maxNumOverlapping is generated; the Filter below drops the
          // candidates that do not actually contain the timestamp.
          val windows = Seq.tabulate(maxNumOverlapping + 1) { i =>
            val windowStart = (windowId + i - maxNumOverlapping) *
              window.slideDuration + window.startTime
            val windowEnd = windowStart + window.windowDuration

            CreateNamedStruct(
              Literal(WINDOW_START) :: windowStart ::
              Literal(WINDOW_END) :: windowEnd :: Nil)
          }

          // Each projection is one candidate window followed by all of the child's columns.
          val projections = windows.map(_ +: child.output)

          // Keep a row only when start <= timeColumn < end.
          val filterExpr =
            window.timeColumn >= windowAttr.getField(WINDOW_START) &&
            window.timeColumn < windowAttr.getField(WINDOW_END)

          val expandedPlan =
            Filter(filterExpr,
              Expand(projections, windowAttr +: child.output, child))

          // Replace every TimeWindow in the operator with the generated window attribute.
          val substitutedPlan = p transformExpressions {
            case _: TimeWindow => windowAttr
          }

          substitutedPlan.withNewChildren(expandedPlan :: Nil)

        case _ :: _ :: _ =>
          p.failAnalysis("Multiple time window expressions would result in a cartesian product " +
            "of rows, therefore they are currently not supported.")

        case _ =>
          p // Return unchanged. Analyzer will throw exception later
      }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ object FunctionRegistry {
expression[UnixTimestamp]("unix_timestamp"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
expression[TimeWindow]("window"),

// collection functions
expression[ArrayContains]("array_contains"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang.StringUtils

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

/**
 * Placeholder expression describing a time window over `timeColumn`. It is never evaluated
 * directly (it is `Unevaluable`) and always reports itself as unresolved, because the analyzer
 * is expected to replace it.
 *
 * @param timeColumn the expression producing the timestamp being bucketed
 * @param windowDuration length of each window; must be greater than 0
 * @param slideDuration distance between consecutive window starts; must be greater than 0 and
 *                      no larger than `windowDuration`
 * @param startTime offset of the window starts; must be at least 0 and less than `slideDuration`
 */
case class TimeWindow(
    timeColumn: Expression,
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long) extends UnaryExpression
  with ImplicitCastInputTypes
  with Unevaluable
  with NonSQLExpression {

  override def child: Expression = timeColumn

  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)

  /** A struct with two timestamp fields, `start` and `end`. */
  override def dataType: DataType = StructType(
    StructField("start", TimestampType) ::
    StructField("end", TimestampType) :: Nil)

  // This expression is replaced in the analyzer.
  override lazy val resolved = false

  /**
   * Checks the input data type first and, only when that succeeds, validates the window
   * duration, slide duration and start time. The first violated constraint is reported.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val inputCheck = super.checkInputDataTypes()
    if (!inputCheck.isSuccess) {
      inputCheck
    } else if (windowDuration <= 0) {
      TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.")
    } else if (slideDuration <= 0) {
      TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.")
    } else if (startTime < 0) {
      TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.")
    } else if (slideDuration > windowDuration) {
      TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" +
        s" to the windowDuration ($windowDuration).")
    } else if (startTime >= slideDuration) {
      TypeCheckFailure(s"The start time ($startTime) must be less than the " +
        s"slideDuration ($slideDuration).")
    } else {
      inputCheck
    }
  }
}

object TimeWindow {
  /**
   * Parses the interval string for a valid time duration. CalendarInterval expects interval
   * strings to start with the string `interval`. For usability, we prepend `interval` to the string
   * if the user omitted it.
   *
   * @param interval The interval string
   * @return The interval duration in microseconds, taken from the parsed interval's
   *         `microseconds` component.
   * @throws IllegalArgumentException if the string is null or blank, cannot be parsed as an
   *                                  interval, or has a non-zero month component
   */
  private def getIntervalInMicroSeconds(interval: String): Long = {
    if (StringUtils.isBlank(interval)) {
      throw new IllegalArgumentException(
        "The window duration, slide duration and start time cannot be null or blank.")
    }
    val intervalString = if (interval.startsWith("interval")) {
      interval
    } else {
      "interval " + interval
    }
    val cal = CalendarInterval.fromString(intervalString)
    if (cal == null) {
      throw new IllegalArgumentException(
        s"The provided interval ($interval) did not correspond to a valid interval string.")
    }
    // Only the microseconds component is returned below, so any month component -- positive or
    // negative -- would be silently dropped. Reject both rather than only positive months.
    if (cal.months != 0) {
      throw new IllegalArgumentException(
        s"Intervals greater than a month is not supported ($interval).")
    }
    cal.microseconds
  }

  /**
   * Builds a [[TimeWindow]] from interval strings, converting each to microseconds.
   *
   * @param timeColumn the expression producing the timestamp being bucketed
   * @param windowDuration interval string for the window length, e.g. "10 seconds"
   * @param slideDuration interval string for the slide between window starts
   * @param startTime interval string for the offset of the window starts
   * @throws IllegalArgumentException if any interval string is blank, unparsable, or has a
   *                                  non-zero month component
   */
  def apply(
      timeColumn: Expression,
      windowDuration: String,
      slideDuration: String,
      startTime: String): TimeWindow = {
    TimeWindow(timeColumn,
      getIntervalInMicroSeconds(windowDuration),
      getIntervalInMicroSeconds(slideDuration),
      getIntervalInMicroSeconds(startTime))
  }
}

/**
 * Expression used internally to convert the TimestampType to Long without losing
 * precision, i.e. in microseconds. Used in time windowing.
 *
 * The child must be a timestamp and the result is typed as a long. Only code generation is
 * overridden in this block; the generated code passes the child's value and null flag through
 * unchanged, merely re-declaring the value under the Java type for LongType.
 */
case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
override def dataType: DataType = LongType
// Emits the child's generated code followed by two declarations: this expression's null flag
// (copied from the child) and its value (the child's value, declared with LongType's Java type).
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val eval = child.gen(ctx)
eval.code +
s"""boolean ${ev.isNull} = ${eval.isNull};
|${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,62 @@ class AnalysisErrorSuite extends AnalysisTest {
testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),
"cannot resolve '`bad_column`'" :: Nil)

// Analysis-error cases for TimeWindow. Each case selects a TimeWindow built from interval
// strings that violate one constraint, and lists fragments of the message produced by
// TimeWindow.checkInputDataTypes for that constraint.
// NOTE(review): `errorTest` is defined outside this view; presumably it asserts that analysis
// fails with a message containing every listed fragment -- confirm.

// Slide (2s) larger than the window (1s).
// NOTE(review): the `s` interpolator on the first fragment has nothing to interpolate.
errorTest(
"slide duration greater than window in time window",
testRelation2.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")),
s"The slide duration " :: " must be less than or equal to the windowDuration " :: Nil
)

// Start time (1 minute) larger than the slide (1s).
errorTest(
"start time greater than slide duration in time window",
testRelation.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")),
"The start time " :: " must be less than the slideDuration " :: Nil
)

// Start time equal to the slide is rejected as well (the check is strict).
errorTest(
"start time equal to slide duration in time window",
testRelation.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")),
"The start time " :: " must be less than the slideDuration " :: Nil
)

// Window duration must be strictly positive: negative value.
errorTest(
"negative window duration in time window",
testRelation.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")),
"The window duration " :: " must be greater than 0." :: Nil
)

// Window duration must be strictly positive: zero value.
errorTest(
"zero window duration in time window",
testRelation.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")),
"The window duration " :: " must be greater than 0." :: Nil
)

// Slide duration must be strictly positive: negative value.
errorTest(
"negative slide duration in time window",
testRelation.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")),
"The slide duration " :: " must be greater than 0." :: Nil
)

// Slide duration must be strictly positive: zero value.
errorTest(
"zero slide duration in time window",
testRelation.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")),
"The slide duration" :: " must be greater than 0." :: Nil
)

// Start time may be zero but not negative.
errorTest(
"negative start time in time window",
testRelation.select(
TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")),
"The start time" :: "must be greater than or equal to 0." :: Nil
)

test("SPARK-6452 regression test") {
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
// Since we manually construct the logical plan at here and Sum only accept
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException

/**
 * Tests for constructing [[TimeWindow]] from interval strings: the expression is unevaluable,
 * invalid interval strings are rejected with a descriptive message, and valid strings are
 * converted to microseconds.
 */
class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("time window is unevaluable") {
    intercept[UnsupportedOperationException] {
      evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second"))
    }
  }

  /**
   * Passes `value` in turn as the window duration, the slide duration and the start time, and
   * asserts that each construction throws an IllegalArgumentException whose message contains
   * `msg`.
   */
  private def checkErrorMessage(msg: String, value: String): Unit = {
    val validDuration = "10 second"
    val validTime = "5 second"
    val e1 = intercept[IllegalArgumentException] {
      TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration
    }
    val e2 = intercept[IllegalArgumentException] {
      TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration
    }
    val e3 = intercept[IllegalArgumentException] {
      TimeWindow(Literal(10L), validDuration, validDuration, value).startTime
    }
    Seq(e1, e2, e3).foreach { e =>
      // The result of `contains` must be asserted; evaluating it alone verifies nothing.
      assert(e.getMessage.contains(msg))
    }
  }

  test("blank intervals throw exception") {
    for (blank <- Seq(null, " ", "\n", "\t")) {
      checkErrorMessage(
        "The window duration, slide duration and start time cannot be null or blank.", blank)
    }
  }

  test("invalid intervals throw exception") {
    checkErrorMessage(
      "did not correspond to a valid interval string.", "2 apples")
  }

  test("intervals greater than a month throws exception") {
    // Must match the message thrown by TimeWindow's interval parsing exactly.
    checkErrorMessage(
      "Intervals greater than a month is not supported (1 month).", "1 month")
  }

  test("interval strings work with and without 'interval' prefix and return microseconds") {
    val validDuration = "10 second"
    for ((text, micros) <- Seq(
      ("1 second", 1000000L), // 1e6
      ("1 minute", 60000000L), // 6e7
      ("2 hours", 7200000000L))) { // 72e8
      assert(TimeWindow(Literal(10L), text, validDuration, "0 seconds").windowDuration === micros)
      assert(TimeWindow(Literal(10L), "interval " + text, validDuration, "0 seconds").windowDuration
        === micros)
    }
  }
}
Loading