diff --git a/dataset/src/main/scala/frameless/TypedWindow.scala b/dataset/src/main/scala/frameless/TypedWindow.scala
new file mode 100644
index 000000000..8042e7530
--- /dev/null
+++ b/dataset/src/main/scala/frameless/TypedWindow.scala
@@ -0,0 +1,156 @@
+package frameless
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{ UnspecifiedFrame, WindowFrame }
+import org.apache.spark.sql.expressions.{ Window, WindowSpec }
+import shapeless.ops.hlist.{ Mapper, ToTraversable }
+import shapeless.{ HList, ProductArgs }
+
+trait OrderedWindow
+trait PartitionedWindow
+
+class TypedWindow[T, A] private (
+  partitionSpec: Seq[UntypedExpression[T]],
+  orderSpec: Seq[UntypedExpression[T]],
+  frame: WindowFrame //TODO. Really a rows or range between
+) {
+
+  def untyped: WindowSpec = Window
+    .partitionBy(partitionSpec.map(e => new Column(e.expr)):_*)
+    .orderBy(orderSpec.map(e => new Column(e.expr)):_*)
+    //TODO: frame
+
+
+  def partitionBy[U](
+    column: TypedColumn[T, U]
+  ): TypedWindow[T, A with PartitionedWindow] =
+    partitionByMany(column)
+
+  def partitionBy[U, V](
+    column1: TypedColumn[T, U],
+    column2: TypedColumn[T, V]
+  ): TypedWindow[T, A with PartitionedWindow] =
+    partitionByMany(column1, column2)
+
+  def partitionBy[U, V, W](
+    column1: TypedColumn[T, U],
+    column2: TypedColumn[T, V],
+    column3: TypedColumn[T,W]
+  ): TypedWindow[T, A with PartitionedWindow] =
+    partitionByMany(column1, column2, column3)
+
+  object partitionByMany extends ProductArgs {
+    def applyProduct[U <: HList](columns: U)
+      (implicit
+        i1: ToTraversable.Aux[U, List, TypedColumn[T, _]]
+      ): TypedWindow[T, A with PartitionedWindow] = {
+      new TypedWindow[T, A with PartitionedWindow](
+        partitionSpec = columns.toList[TypedColumn[T, _]],
+        orderSpec = orderSpec,
+        frame = frame
+      )
+    }
+  }
+
+  def orderBy[U](
+    column: SortedTypedColumn[T, U]
+  ): TypedWindow[T, A with OrderedWindow] =
+    orderByMany(column)
+
+  def orderBy[U, V](
+    column1: SortedTypedColumn[T, U],
+    column2: SortedTypedColumn[T, V]
+  ): TypedWindow[T, A with OrderedWindow] =
+    orderByMany(column1, column2)
+
+  def orderBy[U, V, W](
+    column1: SortedTypedColumn[T, U],
+    column2: SortedTypedColumn[T, V],
+    column3: SortedTypedColumn[T, W]
+  ): TypedWindow[T, A with OrderedWindow] =
+    orderByMany(column1, column2, column3)
+
+  object orderByMany extends ProductArgs {
+    def applyProduct[U <: HList, O <: HList](columns: U)
+      (implicit
+        i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
+        i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
+      ): TypedWindow[T, A with OrderedWindow] = {
+      new TypedWindow[T, A with OrderedWindow](
+        partitionSpec = partitionSpec,
+        orderSpec = i0(columns).toList[SortedTypedColumn[T, _]],
+        frame = frame
+      )
+    }
+  }
+}
+
+object TypedWindow {
+
+  def orderBy[T](
+    column: SortedTypedColumn[T, _]
+  ): TypedWindow[T, OrderedWindow] =
+    new orderByManyNew[T].apply(column) //TODO: This is some ugly syntax
+
+  def orderBy[T](
+    column1: SortedTypedColumn[T, _],
+    column2: SortedTypedColumn[T, _]
+  ): TypedWindow[T, OrderedWindow] =
+    new orderByManyNew[T].apply(column1, column2)
+
+  def orderBy[T](
+    column1: SortedTypedColumn[T, _],
+    column2: SortedTypedColumn[T, _],
+    column3: SortedTypedColumn[T, _]
+  ): TypedWindow[T, OrderedWindow] =
+    new orderByManyNew[T].apply(column1, column2, column3)
+
+  //Need different name because companion class has `orderByMany` defined as well
+  //Need a class and not object in order to define what `T` is explicitly. Otherwise it's a mess
+  //This makes for some pretty horrid syntax though.
+  class orderByManyNew[T] extends ProductArgs {
+    def applyProduct[U <: HList, O <: HList](columns: U)
+      (implicit
+        i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
+        i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
+      ): TypedWindow[T, OrderedWindow] = {
+      new TypedWindow[T, OrderedWindow](
+        partitionSpec = Seq.empty,
+        orderSpec = i0(columns).toList[SortedTypedColumn[T, _]],
+        frame = UnspecifiedFrame
+      )
+    }
+  }
+
+  def partitionBy[T](
+    column: TypedColumn[T, _]
+  ): TypedWindow[T, PartitionedWindow] =
+    new partitionByManyNew[T].apply(column)
+
+  def partitionBy[T](
+    column1: TypedColumn[T, _],
+    column2: TypedColumn[T, _]
+  ): TypedWindow[T, PartitionedWindow] =
+    new partitionByManyNew[T].apply(column1, column2)
+
+  def partitionBy[T](
+    column1: TypedColumn[T, _],
+    column2: TypedColumn[T, _],
+    column3: TypedColumn[T, _]
+  ): TypedWindow[T, PartitionedWindow] =
+    new partitionByManyNew[T].apply(column1, column2, column3)
+
+  class partitionByManyNew[T] extends ProductArgs {
+    def applyProduct[U <: HList](columns: U)
+      (implicit
+        i1: ToTraversable.Aux[U, List, TypedColumn[T, _]]
+      ): TypedWindow[T, PartitionedWindow] = {
+      new TypedWindow[T, PartitionedWindow](
+        partitionSpec = columns.toList[TypedColumn[T, _]],
+        orderSpec = Seq.empty,
+        frame = UnspecifiedFrame
+      )
+    }
+  }
+}
+
diff --git a/dataset/src/main/scala/frameless/functions/WindowFunctions.scala b/dataset/src/main/scala/frameless/functions/WindowFunctions.scala
new file mode 100644
index 000000000..efadec86b
--- /dev/null
+++ b/dataset/src/main/scala/frameless/functions/WindowFunctions.scala
@@ -0,0 +1,17 @@
+package frameless.functions
+
+import frameless.{ OrderedWindow, TypedColumn, TypedWindow }
+import org.apache.spark.sql.{ functions => untyped }
+
+trait WindowFunctions {
+
+  //TODO: TypedAggregate version that can be used in `agg`
+  // whose specs are all either aggs or in the groupBy. Not sure how to do the latter one
+  def denseRank[T, A <: OrderedWindow](over: TypedWindow[T, A]): TypedColumn[T, Int] = {
+    new TypedColumn[T, Int](untyped.dense_rank().over(over.untyped))
+  }
+
+}
+
+//TODO: Move these to the other funcs?
+object WindowFunctions extends WindowFunctions
diff --git a/dataset/src/test/scala/frameless/WindowTests.scala b/dataset/src/test/scala/frameless/WindowTests.scala
new file mode 100644
index 000000000..77e24ace7
--- /dev/null
+++ b/dataset/src/test/scala/frameless/WindowTests.scala
@@ -0,0 +1,41 @@
+package frameless
+
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.{ functions => sfuncs }
+import frameless.functions.WindowFunctions._
+import org.scalacheck.Prop
+import org.scalacheck.Prop._
+
+object WindowTests {
+  case class Foo(a: Int, b: String)
+  case class FooRank(a: Int, b: String, rank: Int)
+}
+
+class WindowTests extends TypedDatasetSuite {
+  test("dense rank") {
+    def prop[
+      A : TypedEncoder,
+      B : TypedEncoder : CatalystOrdered
+    ](data: Vector[X2[A, B]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      val untypedWindow = Window.partitionBy("a").orderBy("b")
+
+      val untyped = TypedDataset.createUnsafe[X3[A, B, Int]](ds.toDF()
+        .withColumn("c", sfuncs.dense_rank().over(untypedWindow))
+      ).collect().run().toVector
+
+      val denseRankWindow = denseRank(TypedWindow.orderBy(ds[B]('b))
+        .partitionBy(ds('a)))
+
+      val typed = ds.withColumn[X3[A, B, Int]](denseRankWindow)
+        .collect().run().toVector
+
+      typed ?= untyped
+    }
+
+    check(forAll(prop[Int, String] _))
+    check(forAll(prop[SQLDate, SQLDate] _))
+    check(forAll(prop[String, Boolean] _))
+  }
+}