diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index 7dc92ee6d..90c138e7d 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -8,8 +8,9 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter} import org.apache.spark.sql._ import shapeless._ -import shapeless.ops.hlist.{Prepend, ToTraversable, Tupler} -import shapeless.ops.record.{Remover, Values} +import shapeless.labelled.FieldType +import shapeless.ops.hlist.{Diff, IsHCons, Prepend, ToTraversable, Tupler} +import shapeless.ops.record.{Keys, Remover, Values} /** [[TypedDataset]] is a safer interface for working with `Dataset`. * @@ -656,10 +657,10 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val * {{{ * case class X(i: Int, j: Int) * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) - * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumn(f('j) === 10) + * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10) * }}} */ - def withColumn[A: TypedEncoder, H <: HList, FH <: HList, Out](ca: TypedColumn[T, A])( + def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out](ca: TypedColumn[T, A])( implicit genOfA: Generic.Aux[T, H], init: Prepend.Aux[H, A :: HNil, FH], @@ -672,6 +673,83 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val TypedDataset.create[Out](selected) } + + /** + * Adds a column to a Dataset so long as the specified output type, `U`, has + * an extra column from `T` that has type `A`. 
+ * + * @example + * {{{ + * case class X(i: Int, j: Int) + * case class Y(i: Int, j: Int, k: Boolean) + * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + * val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10) + * }}} + * @param ca The typed column to add + * @param uEncder the TypedEncoder for the output type `U` + * @param aEncoder the TypedEncoder for the added column type `A` + * @param tgen the LabelledGeneric derived for T + * @param ugen the LabelledGeneric derived for U + * @param noRemovedOldFields proof that no fields of T have been removed in U + * @param newFields the fields of U that are not in T + * @param newKeys keys from newFields + * @param newKey the one and only new key + * @param newField the one and only new field, enforcing that it has type `A` + * @param uKeys the keys of U + * @param uKeysTraverse allows for traversing the keys of U + * @tparam U the output type + * @tparam A The added column type + * @tparam TRep shapeless' record representation of T + * @tparam URep shapeless' record representation of U + * @tparam UKeys the keys of U as an HList + * @tparam NewFields the added fields to T to get U + * @tparam NewKeys the keys of NewFields as an HList + * @tparam NewKey the first, and only, key in NewKeys + * + * @see [[frameless.TypedDataset.withColumnApply#apply]] + */ + def withColumn[U] = new withColumnApply[U] + + class withColumnApply[U] { + def apply[ + A, + TRep <: HList, + URep <: HList, + UKeys <: HList, + NewFields <: HList, + NewKeys <: HList, + NewKey <: Symbol + ]( + ca : TypedColumn[T, A] + )(implicit + uEncder: TypedEncoder[U], + aEncoder: TypedEncoder[A], + tgen: LabelledGeneric.Aux[T, TRep], + ugen: LabelledGeneric.Aux[U, URep], + noRemovedOldFields: Diff.Aux[TRep, URep, HNil], + newFields: Diff.Aux[URep, TRep, NewFields], + newKeys: Keys.Aux[NewFields, NewKeys], + newKey: IsHCons.Aux[NewKeys, NewKey, HNil], + newField: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil], + uKeys: Keys.Aux[URep, UKeys], + uKeysTraverse: 
ToTraversable.Aux[UKeys, Seq, Symbol] + ) = { + val newColumnName = + newKey.head(newKeys()).name + + val dfWithNewColumn = dataset + .toDF() + .withColumn(newColumnName, ca.untyped) + + val newColumns = uKeys.apply.to[Seq].map(_.name).map(dfWithNewColumn.col) + + val selected = dfWithNewColumn + .select(newColumns: _*) + .as[U](TypedExpressionEncoder[U]) + + TypedDataset.create[U](selected) + } + } } object TypedDataset { diff --git a/dataset/src/main/scala/frameless/ops/SmartProject.scala b/dataset/src/main/scala/frameless/ops/SmartProject.scala index aa3788397..94d7df3f6 100644 --- a/dataset/src/main/scala/frameless/ops/SmartProject.scala +++ b/dataset/src/main/scala/frameless/ops/SmartProject.scala @@ -19,7 +19,7 @@ object SmartProject { * @param tgen the LabelledGeneric derived for T * @param ugen the LabelledGeneric derived for U * @param keys the keys of U - * @param select selects all the keys of U from T + * @param select selects all the values from T using the keys of U * @param values selects all the values of LabeledGeneric[U] * @param typeEqualityProof proof that U and the projection of T have the same type * @param keysTraverse allows for traversing the keys of U diff --git a/dataset/src/test/scala/frameless/WithColumnTest.scala b/dataset/src/test/scala/frameless/WithColumnTest.scala index abf9a05da..f89497946 100644 --- a/dataset/src/test/scala/frameless/WithColumnTest.scala +++ b/dataset/src/test/scala/frameless/WithColumnTest.scala @@ -2,18 +2,48 @@ package frameless import org.scalacheck.Prop import org.scalacheck.Prop._ +import shapeless.test.illTyped class WithColumnTest extends TypedDatasetSuite { - test("append five columns") { + import WithColumnTest._ + + test("fail to compile on missing value") { + val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + illTyped { + """val fNew: TypedDataset[XMissing] = f.withColumn[XMissing](f('j) === 10)""" + } + } + + test("fail to compile on different column name") { + val f: 
TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + illTyped { + """val fNew: TypedDataset[XDifferentColumnName] = f.withColumn[XDifferentColumnName](f('j) === 10)""" + } + } + + test("fail to compile on added column name") { + val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + illTyped { + """val fNew: TypedDataset[XAdded] = f.withColumn[XAdded](f('j) === 10)""" + } + } + + test("fail to compile on wrong typed column") { + val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + illTyped { + """val fNew: TypedDataset[XWrongType] = f.withColumn[XWrongType](f('j) === 10)""" + } + } + + test("append four columns") { def prop[A: TypedEncoder](value: A): Prop = { val d = TypedDataset.create(X1(value) :: Nil) - val d1 = d.withColumn(d('a)) - val d2 = d1.withColumn(d1('_1)) - val d3 = d2.withColumn(d2('_2)) - val d4 = d3.withColumn(d3('_3)) - val d5 = d4.withColumn(d4('_4)) + val d1 = d.withColumn[X2[A, A]](d('a)) + val d2 = d1.withColumn[X3[A, A, A]](d1('b)) + val d3 = d2.withColumn[X4[A, A, A, A]](d2('c)) + val d4 = d3.withColumn[X5[A, A, A, A, A]](d3('d)) - (value, value, value, value, value, value) ?= d5.collect().run().head + X5(value, value, value, value, value) ?= d4.collect().run().head } check(prop[Int] _) @@ -23,3 +53,12 @@ class WithColumnTest extends TypedDatasetSuite { check(prop[Option[X1[Boolean]]] _) } } + +object WithColumnTest { + case class X(i: Int, j: Int) + case class XMissing(i: Int, k: Boolean) + case class XDifferentColumnName(i: Int, ji: Int, k: Boolean) + case class XAdded(i: Int, j: Int, k: Boolean, l: Int) + case class XWrongType(i: Int, j: Int, k: Int) + case class XGood(i: Int, j: Int, k: Boolean) +} diff --git a/dataset/src/test/scala/frameless/WithColumnTupledTest.scala b/dataset/src/test/scala/frameless/WithColumnTupledTest.scala new file mode 100644 index 000000000..d8c49a921 --- /dev/null +++ b/dataset/src/test/scala/frameless/WithColumnTupledTest.scala 
@@ -0,0 +1,25 @@ +package frameless + +import org.scalacheck.Prop +import org.scalacheck.Prop._ + +class WithColumnTupledTest extends TypedDatasetSuite { + test("append five columns") { + def prop[A: TypedEncoder](value: A): Prop = { + val d = TypedDataset.create(X1(value) :: Nil) + val d1 = d.withColumnTupled(d('a)) + val d2 = d1.withColumnTupled(d1('_1)) + val d3 = d2.withColumnTupled(d2('_2)) + val d4 = d3.withColumnTupled(d3('_3)) + val d5 = d4.withColumnTupled(d4('_4)) + + (value, value, value, value, value, value) ?= d5.collect().run().head + } + + check(prop[Int] _) + check(prop[Long] _) + check(prop[String] _) + check(prop[SQLDate] _) + check(prop[Option[X1[Boolean]]] _) + } +}