2 changes: 1 addition & 1 deletion core/src/main/scala/frameless/CatalystOrdered.scala
@@ -27,4 +27,4 @@ object CatalystOrdered {
injection: Injection[A, B],
ordered: CatalystOrdered[B]
) : CatalystOrdered[A] = of[A]
}
}
18 changes: 18 additions & 0 deletions dataset/src/main/scala/frameless/TypedColumn.scala
@@ -18,6 +18,16 @@ sealed trait UntypedExpression[T] {
override def toString: String = expr.toString()
}

/** Expression used while sorting a TypedDataset. It prevents all further column operations, since these
  * would lead to runtime errors in Spark.
*
* @tparam T type of dataset
* @tparam U type of column
*/
sealed class SortedTypedColumn[T, U](val expr: Expression)(
implicit val uencoder: TypedEncoder[U]
) extends UntypedExpression[T]

/** Expression used in `select`-like constructions.
*
* Documentation marked "apache/spark" is thanks to apache/spark Contributors
@@ -53,6 +63,14 @@ sealed class TypedColumn[T, U](
else EqualTo(self.expr, other.expr)
}.typed

/** Prepares the column to be used for sorting in descending order by converting it to a [[SortedTypedColumn]]
*/
def desc: SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](FramelessInternals.expr(untyped.desc))

/** Prepares the column to be used for sorting in ascending order by converting it to a [[SortedTypedColumn]]
*/
def asc: SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](FramelessInternals.expr(untyped.asc))

/** Equality test.
* {{{
* df.filter( df.col('a) === 1 )
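For illustration, a minimal usage sketch of the new asc/desc methods (not part of this diff); Person is a hypothetical case class and ds('field) is the existing TypedDataset column selector:

import frameless.TypedDataset

case class Person(name: String, age: Int)  // hypothetical example type

def sortKeys(people: TypedDataset[Person]): Unit = {
  val byAgeDesc = people('age).desc   // SortedTypedColumn[Person, Int]
  val byNameAsc = people('name).asc   // SortedTypedColumn[Person, String]
  // A SortedTypedColumn deliberately exposes no further column operations;
  // its only purpose is to be passed to orderBy/orderByMany (added below).
}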
31 changes: 31 additions & 0 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -772,6 +772,37 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
TypedDataset.create[U](selected)
}
}

/** Orders the TypedDataset using any number of columns.
*/
object orderByMany extends ProductArgs {
def applyProduct[U <: HList](columns: U)
(implicit
i0: SortableColumnTypes[T, U],
i1: ToTraversable.Aux[U, List, UntypedExpression[T]]): TypedDataset[T] = {
val selected = dataset.toDF()
.orderBy(columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)):_*)
TypedDataset.createUnsafe[T](selected)
}
}

/** Orders the TypedDataset using the column selected.
*/
def orderBy[A: CatalystOrdered]
(ca: SortedTypedColumn[T, A]): TypedDataset[T] = orderByMany(ca)

/** Orders the TypedDataset using the columns selected.
*/
def orderBy[A: CatalystOrdered, B: CatalystOrdered]
(ca: SortedTypedColumn[T, A],
cb: SortedTypedColumn[T, B]): TypedDataset[T] = orderByMany(ca, cb)

/** Orders the TypedDataset using the columns selected.
*/
def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]
(ca: SortedTypedColumn[T, A],
cb: SortedTypedColumn[T, B],
cc: SortedTypedColumn[T, C]): TypedDataset[T] = orderByMany(ca, cb, cc)
}

object TypedDataset {
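A sketch of how the new orderBy overloads and orderByMany are meant to be called (not part of this diff), reusing the hypothetical Person type from the sketch above:

import frameless.TypedDataset

// Person is the hypothetical case class introduced in the earlier sketch.
def sortPeople(people: TypedDataset[Person]): TypedDataset[Person] =
  // the fixed-arity overloads cover up to three sorted columns
  people.orderBy(people('age).desc, people('name).asc)

def sortManyPeople(people: TypedDataset[Person]): TypedDataset[Person] =
  // orderByMany accepts any number of sorted columns via ProductArgs
  people.orderByMany(people('age).desc, people('name).asc)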
5 changes: 4 additions & 1 deletion dataset/src/main/scala/frameless/ops/ColumnTypes.scala
@@ -3,6 +3,8 @@ package ops

import shapeless._

import scala.annotation.implicitNotFound

/** A type class to extract the column types out of an HList of [[frameless.TypedColumn]].
*
 * @note This type class is mostly a workaround for an issue with slow implicit derivation for Comapped.
@@ -12,6 +14,7 @@ import shapeless._
* type Out = A :: B :: C :: HNil
* }}}
*/
@implicitNotFound("Unable to proof that all arguments in ${U} are TypedColumns")
trait ColumnTypes[T, U <: HList] {
type Out <: HList
}
@@ -25,4 +28,4 @@ object ColumnTypes {
implicit tail: ColumnTypes.Aux[T, TT, V]
): ColumnTypes.Aux[T, TypedColumn[T, H] :: TT, H :: V] =
new ColumnTypes[T, TypedColumn[T, H] :: TT] {type Out = H :: V}
}
}
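For context, a small sketch (not part of this diff) of what ColumnTypes computes and what the new implicitNotFound message guards; Row is a hypothetical case class:

import shapeless.{::, HNil}
import frameless.TypedColumn
import frameless.ops.ColumnTypes

object ColumnTypesSketch {
  case class Row(i: Int, s: String)  // hypothetical

  // The derived instance recovers the value types behind each TypedColumn;
  // if an element of the HList is not a TypedColumn, resolution fails with
  // the custom message added above.
  implicitly[ColumnTypes.Aux[Row,
    TypedColumn[Row, Int] :: TypedColumn[Row, String] :: HNil,
    Int :: String :: HNil]]
}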
24 changes: 24 additions & 0 deletions dataset/src/main/scala/frameless/ops/SortableColumnTypes.scala
@@ -0,0 +1,24 @@
package frameless.ops

import frameless.{CatalystOrdered, SortedTypedColumn}
import shapeless.{::, HList, HNil}

import scala.annotation.implicitNotFound

@implicitNotFound(
"Either one of the selected columns is not sortable (${U}), " +
"or no ordering (ascending/descending) has been selected. " +
"Select an ordering on any column (t) using the t.asc or t.desc methods.")
trait SortableColumnTypes[T, U <: HList]
Contributor:
I think there might be something in shapeless doing just that.

Contributor:
Got it, it's called Comapped.

Contributor Author @imarios (Jan 20, 2018):
Thank you for all the comments @OlivierBlanvillain. This one does two things: it enforces that the column types are CatalystOrdered and that the columns themselves are SortedTypedColumns. It also lets us provide a custom compilation error (using implicitNotFound). Finally, our ColumnTypes is already a small rework of Comapped, and @kanterov mentioned a performance reason for preferring our own version over relying on Comapped.

object SortableColumnTypes {
implicit def deriveHNil[T]: SortableColumnTypes[T, HNil] =
new SortableColumnTypes[T, HNil] { type Out = HNil }

implicit def deriveCons[T, H, TT <: HList, V <: HList]
(implicit
order: CatalystOrdered[H],
tail: SortableColumnTypes[T, TT]
): SortableColumnTypes[T, SortedTypedColumn[T, H] :: TT] =
new SortableColumnTypes[T, SortedTypedColumn[T, H] :: TT] {}
}
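Per the discussion above, a sketch (not part of this diff) of what SortableColumnTypes enforces: every element must be a SortedTypedColumn whose value type has a CatalystOrdered instance, and failures surface the custom implicitNotFound message. Row is a hypothetical case class:

import shapeless.{::, HNil}
import shapeless.test.illTyped
import frameless.SortedTypedColumn
import frameless.ops.SortableColumnTypes

object SortableColumnTypesSketch {
  case class Row(i: Int, xs: List[Int])  // hypothetical

  // Resolves: Int has a CatalystOrdered instance and the column is sorted.
  implicitly[SortableColumnTypes[Row, SortedTypedColumn[Row, Int] :: HNil]]

  // Does not resolve: List[Int] has no CatalystOrdered instance, so the
  // custom error message is reported instead of a generic implicit error.
  illTyped("""implicitly[SortableColumnTypes[Row, SortedTypedColumn[Row, List[Int]] :: HNil]]""")
}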
77 changes: 77 additions & 0 deletions dataset/src/test/scala/frameless/OrderByTests.scala
@@ -0,0 +1,77 @@
package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._
import org.scalatest.Matchers
import shapeless.test.illTyped

class OrderByTests extends TypedDatasetSuite with Matchers {
test("sorting single column descending") {
def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
val ds = TypedDataset.create(data)

ds.dataset.orderBy(ds.dataset.col("a").desc).collect().toVector.?=(
ds.orderBy(ds('a).desc).collect().run().toVector)
}

check(forAll(prop[Int] _))
check(forAll(prop[String] _))
}

test("sorting single column ascending") {
def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
val ds = TypedDataset.create(data)

ds.dataset.orderBy(ds.dataset.col("a").asc).collect().toVector.?=(
ds.orderBy(ds('a).asc).collect().run().toVector)
}

check(forAll(prop[Int] _))
check(forAll(prop[String] _))
check(forAll(prop[Boolean] _))
}

test("sorting on two columns") {
def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
val ds = TypedDataset.create(data)

val vanillaSpark = ds.dataset.orderBy(ds.dataset.col("a").asc, ds.dataset.col("b").desc).collect().toVector
vanillaSpark.?=(ds.orderByMany(ds('a).asc, ds('b).desc).collect().run().toVector).&&(
vanillaSpark ?= ds.orderBy(ds('a).asc, ds('b).desc).collect().run().toVector
)
}

check(forAll(prop[SQLDate, Long] _))
check(forAll(prop[String, Boolean] _))
check(forAll(prop[SQLTimestamp, Long] _))
}

test("sorting on three columns") {
def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered]
(data: Vector[X3[A, B, A]]): Prop = {
val ds = TypedDataset.create(data)

val vanillaSpark = ds.dataset.orderBy(
ds.dataset.col("a").desc,
ds.dataset.col("b").desc,
ds.dataset.col("c").asc
).collect().toVector

vanillaSpark.?=(ds.orderByMany(ds('a).desc, ds('b).desc, ds('c).asc).collect().run().toVector).&&(
vanillaSpark ?= ds.orderBy(ds('a).desc, ds('b).desc, ds('c).asc).collect().run().toVector)
}

check(forAll(prop[Int, Long] _))
check(forAll(prop[String, SQLDate] _))
check(forAll(prop[Boolean, Long] _))
}

test("fail when selected column is not sortable") {
val d = TypedDataset.create(X2(1, List(1)) :: X2(2, List(2)) :: Nil)
d.orderBy(d('a).desc)
illTyped("""d.orderByDesc('b)""")
d.orderByMany(d('a).desc)
illTyped("""d.orderByMany(d('b).desc)""")
illTyped("""d.orderByMany(d('a))""") // column is correct, but no ordering is selected
}
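As a summary, a hypothetical extra check sketched for illustration (not part of this diff); it would live inside OrderByTests, where an implicit SparkSession and the X2 test helper are available:

test("typed orderBy agrees with vanilla Spark on a small literal example") {
  val d = TypedDataset.create(X2(2, "b") :: X2(1, "a") :: Nil)

  val typed   = d.orderBy(d('a).desc, d('b).asc).collect().run().toVector
  val vanilla = d.dataset.orderBy(d.dataset.col("a").desc, d.dataset.col("b").asc).collect().toVector

  typed shouldEqual vanilla
}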
}