Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 82 additions & 4 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter}
import org.apache.spark.sql._
import shapeless._
import shapeless.ops.hlist.{Prepend, ToTraversable, Tupler}
import shapeless.ops.record.{Remover, Values}
import shapeless.labelled.FieldType
import shapeless.ops.hlist.{Diff, IsHCons, Prepend, ToTraversable, Tupler}
import shapeless.ops.record.{Keys, Remover, Values}

/** [[TypedDataset]] is a safer interface for working with `Dataset`.
*
Expand Down Expand Up @@ -656,10 +657,10 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
* {{{
* case class X(i: Int, j: Int)
* val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
* val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumn(f('j) === 10)
* val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10)
* }}}
*/
def withColumn[A: TypedEncoder, H <: HList, FH <: HList, Out](ca: TypedColumn[T, A])(
def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out](ca: TypedColumn[T, A])(
implicit
genOfA: Generic.Aux[T, H],
init: Prepend.Aux[H, A :: HNil, FH],
Expand All @@ -672,6 +673,83 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val

TypedDataset.create[Out](selected)
}

/**
* Adds a column to a Dataset so long as the specified output type, `U`, has
* an extra column from `T` that has type `A`.
*
* @example
* {{{
* case class X(i: Int, j: Int)
* case class Y(i: Int, j: Int, k: Boolean)
* val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
* val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10)
* }}}
* @param ca The typed column to add
* @param uEncder TypeEncoder for output type U
* @param aEncoder TypeEncoder for added column type `A`
* @param tgen the LabelledGeneric derived for T
* @param ugen the LabelledGeneric derived for U
* @param noRemovedOldFields proof no fields have been removed
* @param newFields diff from T to U
* @param newKeys keys from newFields
* @param newKey the one and only new key
* @param newField the one and only new field enforcing the type of A exists
* @param uKeys the keys of U
* @param uKeysTraverse allows for traversing the keys of U
* @tparam U the output type
* @tparam A The added column type
* @tparam TRep shapeless' record representation of T
* @tparam URep shapeless' record representation of U
* @tparam UKeys the keys of U as an HList
* @tparam NewFields the added fields to T to get U
* @tparam NewKeys the keys of NewFields as an HList
* @tparam NewKey the first, and only, key in NewKey
*
* @see [[frameless.TypedDataset.withColumnApply#apply]]
*/
def withColumn[U] = new withColumnApply[U]

class withColumnApply[U] {
def apply[
A,
TRep <: HList,
URep <: HList,
UKeys <: HList,
NewFields <: HList,
NewKeys <: HList,
NewKey <: Symbol
](
ca : TypedColumn[T, A]
)(implicit
uEncder: TypedEncoder[U],
aEncoder: TypedEncoder[A],
tgen: LabelledGeneric.Aux[T, TRep],
ugen: LabelledGeneric.Aux[U, URep],
noRemovedOldFields: Diff.Aux[TRep, URep, HNil],
newFields: Diff.Aux[URep, TRep, NewFields],
newKeys: Keys.Aux[NewFields, NewKeys],
newKey: IsHCons.Aux[NewKeys, NewKey, HNil],
newField: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil],
uKeys: Keys.Aux[URep, UKeys],
uKeysTraverse: ToTraversable.Aux[UKeys, Seq, Symbol]
) = {
val newColumnName =
newKey.head(newKeys()).name

val dfWithNewColumn = dataset
.toDF()
.withColumn(newColumnName, ca.untyped)

val newColumns = uKeys.apply.to[Seq].map(_.name).map(dfWithNewColumn.col)

val selected = dfWithNewColumn
.select(newColumns: _*)
.as[U](TypedExpressionEncoder[U])

TypedDataset.create[U](selected)
}
}
}

object TypedDataset {
Expand Down
2 changes: 1 addition & 1 deletion dataset/src/main/scala/frameless/ops/SmartProject.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ object SmartProject {
* @param tgen the LabelledGeneric derived for T
* @param ugen the LabelledGeneric derived for U
* @param keys the keys of U
* @param select selects all the keys of U from T
* @param select selects all the values from T using the keys of U
* @param values selects all the values of LabeledGeneric[U]
* @param typeEqualityProof proof that U and the projection of T have the same type
* @param keysTraverse allows for traversing the keys of U
Expand Down
53 changes: 46 additions & 7 deletions dataset/src/test/scala/frameless/WithColumnTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,48 @@ package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._
import shapeless.test.illTyped

class WithColumnTest extends TypedDatasetSuite {
test("append five columns") {
import WithColumnTest._

test("fail to compile on missing value") {
val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XMissing] = f.withColumn[XMissing](f('j) === 10)"""
}
}

test("fail to compile on different column name") {
val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XDifferentColumnName] = f.withColumn[XDifferentColumnName](f('j) === 10)"""
}
}

test("fail to compile on added column name") {
val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XAdded] = f.withColumn[XAdded](f('j) === 10)"""
}
}

test("fail to compile on wrong typed column") {
val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XWrongType] = f.withColumn[XWrongType](f('j) === 10)"""
}
}

test("append four columns") {
def prop[A: TypedEncoder](value: A): Prop = {
val d = TypedDataset.create(X1(value) :: Nil)
val d1 = d.withColumn(d('a))
val d2 = d1.withColumn(d1('_1))
val d3 = d2.withColumn(d2('_2))
val d4 = d3.withColumn(d3('_3))
val d5 = d4.withColumn(d4('_4))
val d1 = d.withColumn[X2[A, A]](d('a))
val d2 = d1.withColumn[X3[A, A, A]](d1('b))
val d3 = d2.withColumn[X4[A, A, A, A]](d2('c))
val d4 = d3.withColumn[X5[A, A, A, A, A]](d3('d))

(value, value, value, value, value, value) ?= d5.collect().run().head
X5(value, value, value, value, value) ?= d4.collect().run().head
}

check(prop[Int] _)
Expand All @@ -23,3 +53,12 @@ class WithColumnTest extends TypedDatasetSuite {
check(prop[Option[X1[Boolean]]] _)
}
}

object WithColumnTest {
case class X(i: Int, j: Int)
case class XMissing(i: Int, k: Boolean)
case class XDifferentColumnName(i: Int, ji: Int, k: Boolean)
case class XAdded(i: Int, j: Int, k: Boolean, l: Int)
case class XWrongType(i: Int, j: Int, k: Int)
case class XGood(i: Int, j: Int, k: Boolean)
}
25 changes: 25 additions & 0 deletions dataset/src/test/scala/frameless/WithColumnTupledTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._

class WithColumnTupledTest extends TypedDatasetSuite {
test("append five columns") {
def prop[A: TypedEncoder](value: A): Prop = {
val d = TypedDataset.create(X1(value) :: Nil)
val d1 = d.withColumnTupled(d('a))
val d2 = d1.withColumnTupled(d1('_1))
val d3 = d2.withColumnTupled(d2('_2))
val d4 = d3.withColumnTupled(d3('_3))
val d5 = d4.withColumnTupled(d4('_4))

(value, value, value, value, value, value) ?= d5.collect().run().head
}

check(prop[Int] _)
check(prop[Long] _)
check(prop[String] _)
check(prop[SQLDate] _)
check(prop[Option[X1[Boolean]]] _)
}
}