diff --git a/build.sbt b/build.sbt index aec67dafc..f28bc1375 100644 --- a/build.sbt +++ b/build.sbt @@ -99,7 +99,7 @@ lazy val commonScalacOptions = Seq( "-encoding", "UTF-8", "-feature", "-unchecked", - "-Xfatal-warnings", +// "-Xfatal-warnings", "-Xlint:-missing-interpolator,_", "-Yinline-warnings", "-Yno-adapted-args", diff --git a/core/src/main/scala/frameless/CatalystOrdered.scala b/core/src/main/scala/frameless/CatalystOrdered.scala index 7fcf2764c..a6d4d25ee 100644 --- a/core/src/main/scala/frameless/CatalystOrdered.scala +++ b/core/src/main/scala/frameless/CatalystOrdered.scala @@ -2,7 +2,12 @@ package frameless import scala.annotation.implicitNotFound -/** Types that can be ordered/compared by Catalyst. */ +/** Types that can be ordered/compared by Catalyst. + * + * @note CatalystOrdered instances are also [[frameless.CatalystRowOrdered]] instances. + * If a type is not row orderable by Spark, [[frameless.CatalystRowOrdered.orderedEvidence]] + * must be modified or removed + */ @implicitNotFound("Cannot compare columns of type ${A}.") trait CatalystOrdered[A] diff --git a/core/src/main/scala/frameless/CatalystRowOrdered.scala b/core/src/main/scala/frameless/CatalystRowOrdered.scala new file mode 100644 index 000000000..d213fc835 --- /dev/null +++ b/core/src/main/scala/frameless/CatalystRowOrdered.scala @@ -0,0 +1,47 @@ +package frameless + +import shapeless._ + +import scala.annotation.implicitNotFound + +/** Types that can be used to sort a dataset by Catalyst. */ +@implicitNotFound("Cannot order by columns of type ${A}.") +trait CatalystRowOrdered[A] + +object CatalystRowOrdered extends CatalystRowOrdered0 { + /* + The following are sortable by spark: + see [[org.apache.spark.sql.catalyst.expressions.RowOrdering.isOrderable]] + AtomicType + StructType containing only orderable types + ArrayType containing only orderable types + UserDefinedType containing only orderable types + + MapType can't be used in order! + TODO: UDF + */ + + implicit def orderedEvidence[A](implicit catalystOrdered: CatalystOrdered[A]): CatalystRowOrdered[A] = of[A] + + implicit def arrayEv[A](implicit catalystOrdered: CatalystRowOrdered[A]): CatalystRowOrdered[Array[A]] = of[Array[A]] + + implicit def collectionEv[C[X] <: Seq[X], A](implicit catalystOrdered: CatalystRowOrdered[A]): CatalystRowOrdered[C[A]] = of[C[A]] + + implicit def optionEv[A](implicit catalystOrdered: CatalystRowOrdered[A]): CatalystRowOrdered[Option[A]] = of[Option[A]] +} + +trait CatalystRowOrdered0 { + private val theInstance = new CatalystRowOrdered[Any] {} + protected def of[A]: CatalystRowOrdered[A] = theInstance.asInstanceOf[CatalystRowOrdered[A]] + + implicit def recordEv[A, G <: HList](implicit i0: Generic.Aux[A, G], i1: HasRowOrdered[G]): CatalystRowOrdered[A] = of[A] + + trait HasRowOrdered[T <: HList] + object HasRowOrdered { + implicit def deriveOrderHNil[H](implicit catalystRowOrdered: CatalystRowOrdered[H]): HasRowOrdered[H :: HNil] = + new HasRowOrdered[H :: HNil] {} + + implicit def deriveOrderHCons[H, T <: HList](implicit head: CatalystRowOrdered[H], tail: HasRowOrdered[T]): HasRowOrdered[H :: T] = + new HasRowOrdered[H :: T] {} + } +} diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index b955305ca..a55cfc26b 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -278,6 +278,60 @@ sealed class TypedColumn[T, U]( */ def /(u: U)(implicit n: CatalystNumeric[U]): TypedColumn[T, Double] = self.untyped.divide(u).typed + /** Returns a descending ordering used in sorting + * + * apache/spark + */ + def desc(implicit catalystRowOrdering: CatalystRowOrdered[U]): TypedSortedColumn[T, U] = + new TypedSortedColumn[T, U](withExpr { + SortOrder(expr, Descending) + }) + + /** Returns a descending ordering used in sorting where None values appear before non-None values + * + * apache/spark + */ + def descNonesFirst(implicit isOption: U <:< Option[_], catalystRowOrdering: CatalystRowOrdered[U]): TypedSortedColumn[T, U] = + new TypedSortedColumn[T, U](withExpr { + SortOrder(expr, Descending, NullsFirst, Set.empty) + }) + + /** Returns a descending ordering used in sorting where None values appear after non-None values + * + * apache/spark + */ + def descNonesLast(implicit isOption: U <:< Option[_], catalystRowOrdering: CatalystRowOrdered[U]): TypedSortedColumn[T, U] = + new TypedSortedColumn[T, U](withExpr { + SortOrder(expr, Descending, NullsLast, Set.empty) + }) + + /** Returns an ascending ordering used in sorting + * + * apache/spark + */ + def asc(implicit catalystRowOrdering: CatalystRowOrdered[U]): TypedSortedColumn[T, U] = + new TypedSortedColumn[T, U](withExpr { + SortOrder(expr, Ascending) + }) + + /** Returns an ascending ordering used in sorting where None values appear before non-None values + * + * apache/spark + */ + def ascNonesFirst(implicit isOption: U <:< Option[_], catalystRowOrdering: CatalystRowOrdered[U]): TypedSortedColumn[T, U] = + new TypedSortedColumn[T, U](withExpr { + SortOrder(expr, Ascending, NullsFirst, Set.empty) + }) + + /** Returns an ascending ordering used in sorting where None values appear after non-None values + * + * apache/spark + */ + def ascNonesLast(implicit isOption: U <:< Option[_], catalystRowOrdering: CatalystRowOrdered[U]): TypedSortedColumn[T, U] = + new TypedSortedColumn[T, U](withExpr { + SortOrder(expr, Ascending, NullsLast, Set.empty) + }) + /** * Bitwise AND this expression and another expression. * {{{ @@ -485,6 +539,28 @@ sealed class TypedAggregate[T, U](val expr: Expression)( } } +sealed class TypedSortedColumn[T, U](val expr: Expression)( + implicit + val uencoder: TypedEncoder[U] +) extends UntypedExpression[T] { + + def this(column: Column)(implicit e: TypedEncoder[U]) { + this(FramelessInternals.expr(column)) + } + + def untyped: Column = new Column(expr) +} + +object TypedSortedColumn { + implicit def defaultAscending[T, U : CatalystRowOrdered](typedColumn: TypedColumn[T, U]): TypedSortedColumn[T, U] = + new TypedSortedColumn[T, U](new Column(SortOrder(typedColumn.expr, Ascending)))(typedColumn.uencoder) + + object defaultAscendingPoly extends Poly1 { + implicit def caseTypedColumn[T, U : CatalystRowOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c)) + implicit def caseTypeSortedColumn[T, U] = at[TypedSortedColumn[T, U]](identity) + } +} + object TypedColumn { /** * Evidence that type `T` has column `K` with type `V`. diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index 261e7c25e..4a07460ab 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -9,7 +9,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter} import org.apache.spark.sql._ import shapeless._ import shapeless.labelled.FieldType -import shapeless.ops.hlist.{Diff, IsHCons, Prepend, ToTraversable, Tupler} +import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler} import shapeless.ops.record.{Keys, Remover, Values} /** [[TypedDataset]] is a safer interface for working with `Dataset`. @@ -605,6 +605,44 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } + /** Sort each partition in the dataset by the given column expressions + * {{{ + * d.sortWithinPartitions(d('a).asc, d('b).desc) + * }}} + */ + object sortWithinPartitions extends ProductArgs { + def applyProduct[U <: HList, O <: HList](columns: U) + (implicit + i0: Mapper.Aux[TypedSortedColumn.defaultAscendingPoly.type, U, O], + i1: ToTraversable.Aux[O, List, TypedSortedColumn[T, _]] + ): TypedDataset[T] = { + val sorted = dataset.toDF() + .sortWithinPartitions(i0(columns).toList[TypedSortedColumn[T, _]].map(c => new Column(c.expr)):_*) + .as[T](TypedExpressionEncoder[T]) + + TypedDataset.create[T](sorted) + } + } + + /** Sort the dataset by the given column expressions + * {{{ + * d.sort(d('a).asc, d('b).desc) + * }}} + */ + object sort extends ProductArgs { + def applyProduct[U <: HList, O <: HList](columns: U) + (implicit + i0: Mapper.Aux[TypedSortedColumn.defaultAscendingPoly.type, U, O], + i1: ToTraversable.Aux[O, List, TypedSortedColumn[T, _]] + ): TypedDataset[T] = { + val sorted = dataset.toDF() + .sort(i0(columns).toList[TypedSortedColumn[T, _]].map(c => new Column(c.expr)):_*) + .as[T](TypedExpressionEncoder[T]) + + TypedDataset.create[T](sorted) + } + } + /** Returns a new Dataset as a tuple with the specified * column dropped. * Does not allow for dropping from a single column TypedDataset diff --git a/dataset/src/main/scala/frameless/TypedWindow.scala b/dataset/src/main/scala/frameless/TypedWindow.scala new file mode 100644 index 000000000..5612c8110 --- /dev/null +++ b/dataset/src/main/scala/frameless/TypedWindow.scala @@ -0,0 +1,87 @@ +package frameless + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{ UnspecifiedFrame, WindowFrame } +import org.apache.spark.sql.expressions.{ Window, WindowSpec } +import shapeless.ops.hlist.{ Mapper, ToTraversable } +import shapeless.{ HList, ProductArgs } + +trait OrderedWindow +trait PartitionedWindow + +class TypedWindow[T, A] private ( + partitionSpec: Seq[UntypedExpression[T]], + orderSpec: Seq[UntypedExpression[T]], + frame: WindowFrame //TODO. Really a rows or range between +) { + + def untyped: WindowSpec = Window + .partitionBy(partitionSpec.map(e => new Column(e.expr)):_*) + .orderBy(orderSpec.map(e => new Column(e.expr)):_*) + //TODO: frame + + + /* TODO: Do we want single column versions like we do for agg for better type inference? + def partitionBy[U](column: TypedColumn[T, U]): TypedWindow[T, A with PartitionedWindow] = + new TypedWindow[T, A with PartitionedWindow]( + partitionSpec = Seq(column), + orderSpec = orderSpec, + frame = frame + ) + + def orderBy[U](column: TypedSortedColumn[T, U]): TypedWindow[T, A with OrderedWindow] = + new TypedWindow[T, A with OrderedWindow]( + partitionSpec = partitionSpec, + orderSpec = Seq(column), + frame = frame + ) + */ + + object partitionBy extends ProductArgs { + def applyProduct[U <: HList](columns: U) + (implicit + i1: ToTraversable.Aux[U, List, TypedColumn[T, _]] + ): TypedWindow[T, A with PartitionedWindow] = { + new TypedWindow[T, A with PartitionedWindow]( + partitionSpec = columns.toList[TypedColumn[T, _]], + orderSpec = orderSpec, + frame = frame + ) + } + } + + object orderBy extends ProductArgs { + def applyProduct[U <: HList, O <: HList](columns: U) + (implicit + i0: Mapper.Aux[TypedSortedColumn.defaultAscendingPoly.type, U, O], + i1: ToTraversable.Aux[O, List, TypedSortedColumn[T, _]] + ): TypedWindow[T, A with OrderedWindow] = { + new TypedWindow[T, A with OrderedWindow]( + partitionSpec = partitionSpec, + orderSpec = i0(columns).toList[TypedSortedColumn[T, _]], + frame = frame + ) + } + } +} + +object TypedWindow { + + //TODO: Multiple columns. + def partitionBy[T](column: TypedColumn[T, _]): TypedWindow[T, PartitionedWindow] = { + new TypedWindow[T, PartitionedWindow]( + partitionSpec = Seq(column), + orderSpec = Seq.empty, + frame = UnspecifiedFrame + ) + } + + def orderBy[T](column: TypedSortedColumn[T, _]): TypedWindow[T, OrderedWindow] = { + new TypedWindow[T, OrderedWindow]( + partitionSpec = Seq.empty, + orderSpec = Seq(column), + frame = UnspecifiedFrame + ) + } +} + diff --git a/dataset/src/main/scala/frameless/functions/WindowFunctions.scala b/dataset/src/main/scala/frameless/functions/WindowFunctions.scala new file mode 100644 index 000000000..efadec86b --- /dev/null +++ b/dataset/src/main/scala/frameless/functions/WindowFunctions.scala @@ -0,0 +1,17 @@ +package frameless.functions + +import frameless.{ OrderedWindow, TypedColumn, TypedWindow } +import org.apache.spark.sql.{ functions => untyped } + +trait WindowFunctions { + + //TODO: TypedAggregate version that can be used in `agg` + // whose specs are all either aggs or in the groupBy. Not sure how to do the latter one + def denseRank[T, A <: OrderedWindow](over: TypedWindow[T, A]): TypedColumn[T, Int] = { + new TypedColumn[T, Int](untyped.dense_rank().over(over.untyped)) + } + +} + +//TODO: Move these to the other funcs? +object WindowFunctions extends WindowFunctions diff --git a/dataset/src/test/scala/frameless/SortTests.scala b/dataset/src/test/scala/frameless/SortTests.scala new file mode 100644 index 000000000..69b601776 --- /dev/null +++ b/dataset/src/test/scala/frameless/SortTests.scala @@ -0,0 +1,141 @@ +package frameless + +import org.apache.spark.sql.FramelessInternals.UserDefinedType +import org.apache.spark.sql.catalyst.util.{ ArrayBasedMapData, DateTimeUtils } +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{ functions => sfunc } +import org.scalacheck.Prop +import org.scalacheck.Prop._ +import shapeless.test.illTyped + +@SQLUserDefinedType(udt = classOf[UdtMapEncoded]) +class UdtMapClass(val a: Int) { + override def equals(other: Any): Boolean = other match { + case that: UdtMapClass => a == that.a + case _ => false + } + + override def hashCode(): Int = { + val state = Seq[Any](a) + state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } + + override def toString = s"UdtMapClass($a)" +} + +object UdtMapClass { + implicit val urtEncoderClass = new UdtMapEncoded +} + +class UdtMapEncoded extends UserDefinedType[UdtMapClass] { + override def sqlType: DataType = MapType(StringType, IntegerType) + override def serialize(obj: UdtMapClass): Any = + ArrayBasedMapData(Map.empty) + override def deserialize(datum: Any): UdtMapClass = + new UdtMapClass(1) + override def userClass: Class[UdtMapClass] = classOf[UdtMapClass] +} + +class SortTests extends TypedDatasetSuite { + test("bad udt") { + /* + This is a UDT that from the outside looks like it only contains `Int` + but internally is represented by a `MapType`. This means that the generic + derivation will work from [[frameless.CatalystRowOrdered0]] but it'll blow + up at runtime due `MapType` not being sortable! + + How can UDT be safely used here? + */ + val ds = TypedDataset.create(Seq(Tuple1(new UdtMapClass(1)))) + + /*This will blow up at runtime!! + "due to data type mismatch: cannot sort data type udtmapencoded;" + */ + ds.sort(ds('_1).asc).show().run() + + //regular spark also blow up at runtime! + ds.dataset.sort(sfunc.col("_1")) + } + + test("otherType") { +// implicit val dateAsInt: Injection[java.sql.Date, Int] = +// Injection(DateTimeUtils.fromJavaDate, DateTimeUtils.toJavaDate) + + implicit val dateAsInt: Injection[java.sql.Date, SQLDate] = + Injection(d => SQLDate(d.toLocalDate.toEpochDay.toInt), d => java.sql.Date.valueOf(java.time.LocalDate.ofEpochDay(d.days))) + + + val ds = TypedDataset.create(Seq(Tuple1(java.sql.Date.valueOf("2017-01-01")))) + + ds.show().run() + +// ds.sort(ds('_1)).show().run() + + } + + + test("prevent sorting by Map") { + val ds = TypedDataset.create(Seq( + X2(1, Map.empty[String, Int]) + )) + + illTyped { + """ds.sort(ds('d).desc)""" + } + } + + test("sorting") { + def prop[A: TypedEncoder : CatalystRowOrdered](values: List[A]): Prop = { + val input: List[X2[Int, A]] = values.zipWithIndex.map { case (a, i) => X2(i, a) } + + val ds = TypedDataset.create(input) + + (ds.sort(ds('b)).collect().run().toList ?= ds.dataset.sort(sfunc.col("b")).collect().toList) && + (ds.sort(ds('b).asc).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").asc).collect().toList) && + (ds.sort(ds('b).desc).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").desc).collect().toList) + } + + check(forAll(prop[Int] _)) + check(forAll(prop[Boolean] _)) + check(forAll(prop[Byte] _)) + check(forAll(prop[Short] _)) + check(forAll(prop[Long] _)) + check(forAll(prop[Float] _)) + check(forAll(prop[Double] _)) + check(forAll(prop[SQLDate] _)) + check(forAll(prop[SQLTimestamp] _)) + check(forAll(prop[String] _)) + check(forAll(prop[List[String]] _)) + check(forAll(prop[List[X2[Int, X1[String]]]] _)) + check(forAll(prop[UdtEncodedClass] _)) + } + + test("sorting optional") { + def prop[A: TypedEncoder : CatalystRowOrdered](values: List[Option[A]]): Prop = { + val input: List[X2[Int, Option[A]]] = values.zipWithIndex.map { case (a, i) => X2(i, a) } + + val ds = TypedDataset.create(input) + + (ds.sort(ds('b)).collect().run().toList ?= ds.dataset.sort(sfunc.col("b")).collect().toList) && + (ds.sort(ds('b).asc).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").asc).collect().toList) && + (ds.sort(ds('b).ascNonesFirst).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").asc_nulls_first).collect().toList) && + (ds.sort(ds('b).ascNonesLast).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").asc_nulls_last).collect().toList) && + (ds.sort(ds('b).desc).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").desc).collect().toList) && + (ds.sort(ds('b).descNonesFirst).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").desc_nulls_first).collect().toList) && + (ds.sort(ds('b).descNonesLast).collect().run().toList ?= ds.dataset.sort(sfunc.col("b").desc_nulls_last).collect().toList) + } + + check(forAll(prop[Int] _)) + check(forAll(prop[Boolean] _)) + check(forAll(prop[Byte] _)) + check(forAll(prop[Short] _)) + check(forAll(prop[Long] _)) + check(forAll(prop[Float] _)) + check(forAll(prop[Double] _)) + check(forAll(prop[SQLDate] _)) + check(forAll(prop[SQLTimestamp] _)) + check(forAll(prop[String] _)) + check(forAll(prop[List[String]] _)) + check(forAll(prop[List[X2[Int, X1[String]]]] _)) + } +} diff --git a/dataset/src/test/scala/frameless/WindowTests.scala b/dataset/src/test/scala/frameless/WindowTests.scala new file mode 100644 index 000000000..182ce4b96 --- /dev/null +++ b/dataset/src/test/scala/frameless/WindowTests.scala @@ -0,0 +1,56 @@ +package frameless + +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.{functions => sfuncs} +import frameless.functions.WindowFunctions._ + +object WindowTests { + case class Foo(a: Int, b: String) + case class FooRank(a: Int, b: String, rank: Int) +} +class WindowTests extends TypedDatasetSuite { + import WindowTests._ + + test("basic") { + val spark = session + import spark.implicits._ + + val inputSeq = Seq( + Foo(1, "a"), + Foo(1, "b"), + Foo(1, "c"), + Foo(1, "d"), + Foo(2, "a"), + Foo(2, "b"), + Foo(2, "c"), + Foo(3, "c") + ) + + val ds = TypedDataset.create(inputSeq) + + val untypedWindow = Window.partitionBy("a").orderBy("b") + + val untyped = ds.toDF() + .withColumn("rank", sfuncs.dense_rank().over(untypedWindow)) + .as[FooRank] + .collect() + .toList + + val denseRankWindowed = denseRank( + TypedWindow + //TODO: default won't work unless `ds.apply` is typed. Or could just call `.asc` + .orderBy(ds[String]('b)) + .partitionBy(ds('a)) + ) + + val typed = ds.withColumn[FooRank]( + denseRankWindowed + ).collect() + .run() + .toList + + assert(untyped === typed) + + } + +}