diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala
index b955305ca..9209b7684 100644
--- a/dataset/src/main/scala/frameless/TypedColumn.scala
+++ b/dataset/src/main/scala/frameless/TypedColumn.scala
@@ -18,6 +18,16 @@ sealed trait UntypedExpression[T] {
   override def toString: String = expr.toString()
 }
 
+/** Expression used while sorting a TypedDataset. It prevents all other future column operations since they
+  * lead to runtime errors in Spark.
+  *
+  * @tparam T type of dataset
+  * @tparam U type of column
+  */
+sealed class SortedTypedColumn[T, U](val expr: Expression)(
+  implicit val uencoder: TypedEncoder[U]
+) extends UntypedExpression[T]
+
 /** Expression used in `select`-like constructions.
   *
   * Documentation marked "apache/spark" is thanks to apache/spark Contributors
@@ -53,6 +63,14 @@ sealed class TypedColumn[T, U](
     else EqualTo(self.expr, other.expr)
   }.typed
 
+  /** Prepares the column to be used for sorting in descending order by converting it to a [[SortedTypedColumn]]
+    */
+  def desc: SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](FramelessInternals.expr(untyped.desc))
+
+  /** Prepares the column to be used for sorting in ascending order by converting it to a [[SortedTypedColumn]]
+    */
+  def asc: SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](FramelessInternals.expr(untyped.asc))
+
   /** Equality test.
     * {{{
     *   df.filter( df.col('a) === 1 )
diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index 261e7c25e..9e1478f23 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -772,6 +772,41 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
       TypedDataset.create[U](selected)
     }
   }
+
+  /** Orders the TypedDataset using any number of columns. Every column must carry an ordering
+    * selected with `asc` or `desc`:
+    * {{{
+    *   ds.orderByMany(ds('a).asc, ds('b).desc)
+    * }}}
+    */
+  object orderByMany extends ProductArgs {
+    def applyProduct[U <: HList](columns: U)
+      (implicit
+        i0: SortableColumnTypes[T, U],
+        i1: ToTraversable.Aux[U, List, UntypedExpression[T]]): TypedDataset[T] = {
+      val selected = dataset.toDF()
+        .orderBy(columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)):_*)
+      TypedDataset.createUnsafe[T](selected)
+    }
+  }
+
+  /** Orders the TypedDataset using the column selected.
+    */
+  def orderBy[A: CatalystOrdered]
+    (ca: SortedTypedColumn[T, A]): TypedDataset[T] = orderByMany(ca)
+
+  /** Orders the TypedDataset using the columns selected.
+    */
+  def orderBy[A: CatalystOrdered, B: CatalystOrdered]
+    (ca: SortedTypedColumn[T, A],
+     cb: SortedTypedColumn[T, B]): TypedDataset[T] = orderByMany(ca, cb)
+
+  /** Orders the TypedDataset using the columns selected.
+    */
+  def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]
+    (ca: SortedTypedColumn[T, A],
+     cb: SortedTypedColumn[T, B],
+     cc: SortedTypedColumn[T, C]): TypedDataset[T] = orderByMany(ca, cb, cc)
 }
 
 object TypedDataset {
diff --git a/dataset/src/main/scala/frameless/ops/ColumnTypes.scala b/dataset/src/main/scala/frameless/ops/ColumnTypes.scala
index e5ae6aea2..0195b4b00 100644
--- a/dataset/src/main/scala/frameless/ops/ColumnTypes.scala
+++ b/dataset/src/main/scala/frameless/ops/ColumnTypes.scala
@@ -3,6 +3,8 @@ package ops
 
 import shapeless._
 
+import scala.annotation.implicitNotFound
+
 /** A type class to extract the column types out of an HList of [[frameless.TypedColumn]].
   *
   * @note This type class is mostly a workaround to issue with slow implicit derivation for Comapped.
@@ -12,6 +14,7 @@ import shapeless._
   *   type Out = A :: B :: C :: HNil
   * }}}
   */
+@implicitNotFound("Unable to prove that all arguments in ${U} are TypedColumns")
 trait ColumnTypes[T, U <: HList] {
   type Out <: HList
 }
diff --git a/dataset/src/main/scala/frameless/ops/SortableColumnTypes.scala b/dataset/src/main/scala/frameless/ops/SortableColumnTypes.scala
new file mode 100644
index 000000000..1ead5df73
--- /dev/null
+++ b/dataset/src/main/scala/frameless/ops/SortableColumnTypes.scala
@@ -0,0 +1,32 @@
+package frameless.ops
+
+import frameless.{CatalystOrdered, SortedTypedColumn}
+import shapeless.{::, HList, HNil}
+
+import scala.annotation.implicitNotFound
+
+/** Evidence that every element of `U` is a [[frameless.SortedTypedColumn]] over `T` whose column
+  * type has a [[frameless.CatalystOrdered]] instance.
+  *
+  * @tparam T type of dataset
+  * @tparam U HList of the columns to sort by
+  */
+@implicitNotFound(
+  "Either one of the selected columns is not sortable (${U}), " +
+    "or no ordering (ascending/descending) has been selected. " +
+    "Select an ordering on any column (t) using the t.asc or t.desc methods.")
+trait SortableColumnTypes[T, U <: HList]
+
+object SortableColumnTypes {
+  /** An empty list of columns needs no further evidence. */
+  implicit def deriveHNil[T]: SortableColumnTypes[T, HNil] =
+    new SortableColumnTypes[T, HNil] {}
+
+  /** A non-empty list qualifies when its head is an ordered column and its tail qualifies. */
+  implicit def deriveCons[T, H, TT <: HList]
+    (implicit
+      order: CatalystOrdered[H],
+      tail: SortableColumnTypes[T, TT]
+    ): SortableColumnTypes[T, SortedTypedColumn[T, H] :: TT] =
+      new SortableColumnTypes[T, SortedTypedColumn[T, H] :: TT] {}
+}
diff --git a/dataset/src/test/scala/frameless/OrderByTests.scala b/dataset/src/test/scala/frameless/OrderByTests.scala
new file mode 100644
index 000000000..c7ee2d609
--- /dev/null
+++ b/dataset/src/test/scala/frameless/OrderByTests.scala
@@ -0,0 +1,77 @@
+package frameless
+
+import org.scalacheck.Prop
+import org.scalacheck.Prop._
+import org.scalatest.Matchers
+import shapeless.test.illTyped
+
+class OrderByTests extends TypedDatasetSuite with Matchers {
+  test("sorting single column descending") {
+    def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      ds.dataset.orderBy(ds.dataset.col("a").desc).collect().toVector.?=(
+        ds.orderBy(ds('a).desc).collect().run().toVector)
+    }
+
+    check(forAll(prop[Int] _))
+    check(forAll(prop[String] _))
+  }
+
+  test("sorting single column ascending") {
+    def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      ds.dataset.orderBy(ds.dataset.col("a").asc).collect().toVector.?=(
+        ds.orderBy(ds('a).asc).collect().run().toVector)
+    }
+
+    check(forAll(prop[Int] _))
+    check(forAll(prop[String] _))
+    check(forAll(prop[Boolean] _))
+  }
+
+  test("sorting on two columns") {
+    def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      val vanillaSpark = ds.dataset.orderBy(ds.dataset.col("a").asc, ds.dataset.col("b").desc).collect().toVector
+      vanillaSpark.?=(ds.orderByMany(ds('a).asc, ds('b).desc).collect().run().toVector).&&(
+        vanillaSpark ?= ds.orderBy(ds('a).asc, ds('b).desc).collect().run().toVector
+      )
+    }
+
+    check(forAll(prop[SQLDate, Long] _))
+    check(forAll(prop[String, Boolean] _))
+    check(forAll(prop[SQLTimestamp, Long] _))
+  }
+
+  test("sorting on three columns") {
+    def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered]
+    (data: Vector[X3[A, B, A]]): Prop = {
+      val ds = TypedDataset.create(data)
+
+      val vanillaSpark = ds.dataset.orderBy(
+        ds.dataset.col("a").desc,
+        ds.dataset.col("b").desc,
+        ds.dataset.col("c").asc
+      ).collect().toVector
+
+      vanillaSpark.?=(ds.orderByMany(ds('a).desc, ds('b).desc, ds('c).asc).collect().run().toVector).&&(
+        vanillaSpark ?= ds.orderBy(ds('a).desc, ds('b).desc, ds('c).asc).collect().run().toVector)
+    }
+
+    check(forAll(prop[Int, Long] _))
+    check(forAll(prop[String, SQLDate] _))
+    check(forAll(prop[Boolean, Long] _))
+  }
+
+  test("fail when selected column is not sortable") {
+    val d = TypedDataset.create(X2(1, List(1)) :: X2(2, List(2)) :: Nil)
+    d.orderBy(d('a).desc)
+    illTyped("""d.orderBy(d('b).desc)""") // 'b holds List[Int], which this test expects to be unsortable
+    d.orderByMany(d('a).desc)
+    illTyped("""d.orderByMany(d('b).desc)""")
+    illTyped("""d.orderByMany(d('a))""") // column is correct, but no ordering is selected
+  }
+}