diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3b20ba5177efd..5017ab5b3646d 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1053,6 +1053,55 @@ def to_utc_timestamp(timestamp, tz):
     return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
 
 
+@since(2.0)
+@ignore_unicode_prefix
+def window(timeColumn, windowDuration, slideDuration=None, startTime=None):
+    """Bucketize rows into one or more time windows given a timestamp specifying column. Window
+    starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
+    [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
+    the order of months are not supported.
+
+    The time column must be of TimestampType.
+
+    Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
+    interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
+    If the `slideDuration` is not provided, the windows will be tumbling windows.
+
+    The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start
+    window intervals. For example, in order to have hourly tumbling windows that start 15 minutes
+    past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`.
+
+    The output column will be a struct called 'window' by default with the nested columns 'start'
+    and 'end', where 'start' and 'end' will be of `TimestampType`.
+
+    >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
+    >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum"))
+    >>> w.select(w.window.start.cast("string").alias("start"),
+    ...          w.window.end.cast("string").alias("end"), "sum").collect()
+    [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)]
+    """
+    def check_string_field(field, fieldName):
+        if not field or type(field) is not str:
+            raise TypeError("%s should be provided as a string" % fieldName)
+
+    sc = SparkContext._active_spark_context
+    time_col = _to_java_column(timeColumn)
+    check_string_field(windowDuration, "windowDuration")
+    if slideDuration and startTime:
+        check_string_field(slideDuration, "slideDuration")
+        check_string_field(startTime, "startTime")
+        res = sc._jvm.functions.window(time_col, windowDuration, slideDuration, startTime)
+    elif slideDuration:
+        check_string_field(slideDuration, "slideDuration")
+        res = sc._jvm.functions.window(time_col, windowDuration, slideDuration)
+    elif startTime:
+        check_string_field(startTime, "startTime")
+        res = sc._jvm.functions.window(time_col, windowDuration, windowDuration, startTime)
+    else:
+        res = sc._jvm.functions.window(time_col, windowDuration)
+    return Column(res)
+
+
 # ---------------------------- misc functions ----------------------------------
 
 @since(1.5)
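Note: the four Python branches above map onto the three JVM overloads of org.apache.spark.sql.functions.window; the startTime-only branch reuses the four-argument overload with slideDuration set equal to windowDuration, i.e. a tumbling window with an offset. For reference, the Scala signatures being invoked (they already exist in functions.scala, whose docs are updated further down; shown here for orientation only, not part of the patch):

    def window(timeColumn: Column, windowDuration: String): Column
    def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column
    def window(
        timeColumn: Column,
        windowDuration: String,
        slideDuration: String,
        startTime: String): Column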
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ca8db3cbc5993..a8adfd478a63e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -362,7 +362,10 @@ object FunctionRegistry {
       }
       Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
         case Success(e) => e
-        case Failure(e) => throw new AnalysisException(e.getMessage)
+        case Failure(e) =>
+          // the exception is an invocation exception. To get a meaningful message, we need the
+          // cause.
+          throw new AnalysisException(e.getCause.getMessage)
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 8e13833486931..daf3de95dd9ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.commons.lang.StringUtils
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
@@ -34,6 +35,28 @@ case class TimeWindow(
   with Unevaluable
   with NonSQLExpression {
 
+  //////////////////////////
+  // SQL Constructors
+  //////////////////////////
+
+  def this(
+      timeColumn: Expression,
+      windowDuration: Expression,
+      slideDuration: Expression,
+      startTime: Expression) = {
+    this(timeColumn, TimeWindow.parseExpression(windowDuration),
+      TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
+  }
+
+  def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
+    this(timeColumn, TimeWindow.parseExpression(windowDuration),
+      TimeWindow.parseExpression(slideDuration), 0)
+  }
+
+  def this(timeColumn: Expression, windowDuration: Expression) = {
+    this(timeColumn, windowDuration, windowDuration)
+  }
+
   override def child: Expression = timeColumn
   override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
   override def dataType: DataType = new StructType()
@@ -104,6 +127,18 @@ object TimeWindow {
     cal.microseconds
   }
 
+  /**
+   * Parses the duration expression to generate the long value for the original constructor so
+   * that we can use `window` in SQL.
+   */
+  private def parseExpression(expr: Expression): Long = expr match {
+    case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
+    case IntegerLiteral(i) => i.toLong
+    case NonNullLiteral(l, LongType) => l.toString.toLong
+    case _ => throw new AnalysisException("The duration and time inputs to window must be " +
+      "an integer, long or string literal.")
+  }
+
   def apply(
       timeColumn: Expression,
       windowDuration: String,
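Since parseExpression accepts interval strings as well as integer and long literals (numeric values are taken as microseconds), the new SQL constructors can be exercised in several equivalent ways. A minimal sketch using catalyst types, with a hypothetical unresolved attribute `ts`; all three build the same tumbling 10-second window:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{Literal, TimeWindow}
    import org.apache.spark.sql.types.LongType

    val ts = UnresolvedAttribute("ts") // hypothetical timestamp attribute
    new TimeWindow(ts, Literal("10 seconds"))                // interval string
    new TimeWindow(ts, Literal(10000000))                    // Int, taken as microseconds
    new TimeWindow(ts, Literal.create(10000000L, LongType))  // Long, taken as microseconds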
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 6b7997e903a99..232ca4358865a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -22,6 +22,7 @@ import java.util.UUID
 import scala.collection.Map
 import scala.collection.mutable.Stack
 
+import org.apache.commons.lang.ClassUtils
 import org.json4s.JsonAST._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
@@ -365,20 +366,32 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * @param newArgs the new product arguments.
    */
   def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") {
+    // Skip no-arg constructors that are just there for kryo.
     val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
     if (ctors.isEmpty) {
       sys.error(s"No valid constructor for $nodeName")
     }
-    val defaultCtor = ctors.maxBy(_.getParameterTypes.size)
+    val allArgs: Array[AnyRef] = if (otherCopyArgs.isEmpty) {
+      newArgs
+    } else {
+      newArgs ++ otherCopyArgs
+    }
+    val defaultCtor = ctors.find { ctor =>
+      if (ctor.getParameterTypes.length != allArgs.length) {
+        false
+      } else if (allArgs.contains(null)) {
+        // if there is a `null`, we can't figure out the class, therefore we should just fall back
+        // to the older heuristic
+        false
+      } else {
+        val argsArray: Array[Class[_]] = allArgs.map(_.getClass)
+        ClassUtils.isAssignable(argsArray, ctor.getParameterTypes, true /* autoboxing */)
+      }
+    }.getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to the older heuristic
 
     try {
       CurrentOrigin.withOrigin(origin) {
-        // Skip no-arg constructors that are just there for kryo.
-        if (otherCopyArgs.isEmpty) {
-          defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType]
-        } else {
-          defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType]
-        }
+        defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
       }
     } catch {
       case e: java.lang.IllegalArgumentException =>
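The constructor search above matters because TimeWindow now has two four-argument constructors, the case-class one taking (Expression, Long, Long, Long) and the SQL one taking four Expressions, so choosing the constructor with the most parameters is no longer unambiguous. The ClassUtils call matches the runtime classes of the arguments against each constructor's parameter types, with autoboxing so boxed values satisfy primitive parameters. A small sketch of the commons-lang 2.x behavior being relied on (illustrative, not part of the patch):

    import org.apache.commons.lang.ClassUtils

    val argClasses: Array[Class[_]] = Array(classOf[java.lang.Long])
    val paramTypes: Array[Class[_]] = Array(java.lang.Long.TYPE)
    // true: with autoboxing enabled, a boxed Long argument satisfies a primitive
    // long parameter, which is what lets makeCopy pick the case-class constructor
    // (Expression, Long, Long, Long) over the all-Expression SQL constructor.
    ClassUtils.isAssignable(argClasses, paramTypes, true)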
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
index 71f969aee2ee4..b82cf8d1693e2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.scalatest.PrivateMethodTester
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.types.LongType
 
-class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper {
+class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
 
   test("time window is unevaluable") {
     intercept[UnsupportedOperationException] {
@@ -73,4 +76,36 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper {
         === seconds)
     }
   }
+
+  private val parseExpression = PrivateMethod[Long]('parseExpression)
+
+  test("parse sql expression for duration in microseconds - string") {
+    val dur = TimeWindow.invokePrivate(parseExpression(Literal("5 seconds")))
+    assert(dur.isInstanceOf[Long])
+    assert(dur === 5000000)
+  }
+
+  test("parse sql expression for duration in microseconds - integer") {
+    val dur = TimeWindow.invokePrivate(parseExpression(Literal(100)))
+    assert(dur.isInstanceOf[Long])
+    assert(dur === 100)
+  }
+
+  test("parse sql expression for duration in microseconds - long") {
+    val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
+    assert(dur.isInstanceOf[Long])
+    assert(dur === (2L << 52))
+  }
+
+  test("parse sql expression for duration in microseconds - invalid interval") {
+    intercept[IllegalArgumentException] {
+      TimeWindow.invokePrivate(parseExpression(Literal("2 apples")))
+    }
+  }
+
+  test("parse sql expression for duration in microseconds - invalid expression") {
+    intercept[AnalysisException] {
+      TimeWindow.invokePrivate(parseExpression(Rand(123)))
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 74906050acbb3..b6b01259a3d26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2575,8 +2575,7 @@ object functions {
    * processing time.
    *
    * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
-   *                   The time can be as TimestampType or LongType, however when using LongType,
-   *                   the time must be given in seconds.
+   *                   The time column must be of TimestampType.
    * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
    *                       `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for
    *                       valid duration identifiers.
@@ -2630,8 +2629,7 @@ object functions {
    * processing time.
    *
    * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
-   *                   The time can be as TimestampType or LongType, however when using LongType,
-   *                   the time must be given in seconds.
+   *                   The time column must be of TimestampType.
    * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
    *                       `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for
    *                       valid duration identifiers.
@@ -2673,8 +2671,7 @@ object functions {
    * processing time.
    *
    * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
-   *                   The time can be as TimestampType or LongType, however when using LongType,
-   *                   the time must be given in seconds.
+   *                   The time column must be of TimestampType.
    * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
    *                       `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for
    *                       valid duration identifiers.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index e8103a31d5833..06584ec21e2f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -239,4 +239,61 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
       Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1))
     )
   }
+
+  private def withTempTable(f: String => Unit): Unit = {
+    val tableName = "temp"
+    Seq(
+      ("2016-03-27 19:39:34", 1),
+      ("2016-03-27 19:39:56", 2),
+      ("2016-03-27 19:39:27", 4)).toDF("time", "value").registerTempTable(tableName)
+    try {
+      f(tableName)
+    } finally {
+      sqlContext.dropTempTable(tableName)
+    }
+  }
+
+  test("time window in SQL with single string expression") {
+    withTempTable { table =>
+      checkAnswer(
+        sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""")
+          .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
+        Seq(
+          Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
+          Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+          Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2)
+        )
+      )
+    }
+  }
+
+  test("time window in SQL with two expressions") {
+    withTempTable { table =>
+      checkAnswer(
+        sqlContext.sql(
+          s"""select window(time, "10 seconds", 10000000), value from $table""")
+          .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
+        Seq(
+          Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
+          Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+          Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2)
+        )
+      )
+    }
+  }
+
+  test("time window in SQL with three expressions") {
expressions") { + withTempTable { table => + checkAnswer( + sqlContext.sql( + s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } }