diff --git a/build.sbt b/build.sbt index 4048d31..16400f3 100644 --- a/build.sbt +++ b/build.sbt @@ -18,7 +18,7 @@ lazy val compilerOptions = Seq( ) lazy val shapelessVersion = "2.3.2" -lazy val sparkVersion = "2.1.0" +lazy val sparkVersion = "2.1.2" lazy val scalatestVersion = "3.0.1" lazy val baseSettings = Seq( @@ -28,6 +28,7 @@ lazy val baseSettings = Seq( ) ++ Seq( "org.scalatest" %% "scalatest" % scalatestVersion ).map(_ % "test"), + parallelExecution in Test := false, scalacOptions ++= compilerOptions ) diff --git a/core/src/main/scala/ste/encoder.scala b/core/src/main/scala/ste/encoder.scala index 925bf77..e1f7d0f 100644 --- a/core/src/main/scala/ste/encoder.scala +++ b/core/src/main/scala/ste/encoder.scala @@ -29,7 +29,8 @@ import scala.annotation.StaticAnnotation import scala.collection.generic.IsTraversableOnce -final class Meta(val metadata: Metadata) extends StaticAnnotation +final case class Meta(metadata: Metadata) extends StaticAnnotation +final case class Flatten(times: Int = 1, keys: Seq[String] = Seq()) extends StaticAnnotation @annotation.implicitNotFound(""" Type ${A} does not have a DataTypeEncoder defined in the library. @@ -37,15 +38,17 @@ final class Meta(val metadata: Metadata) extends StaticAnnotation """) sealed trait DataTypeEncoder[A] { def encode: DataType + def fields: Option[Seq[StructField]] def nullable: Boolean } object DataTypeEncoder { def apply[A](implicit enc: DataTypeEncoder[A]): DataTypeEncoder[A] = enc - def pure[A](dt: DataType, isNullable: Boolean = false): DataTypeEncoder[A] = + def pure[A](dt: DataType, f: Option[Seq[StructField]] = None, isNullable: Boolean = false): DataTypeEncoder[A] = new DataTypeEncoder[A] { def encode: DataType = dt + def fields: Option[Seq[StructField]] = f def nullable: Boolean = isNullable } } @@ -56,6 +59,7 @@ object DataTypeEncoder { """) sealed trait StructTypeEncoder[A] extends DataTypeEncoder[A] { def encode: StructType + def fields: Option[Seq[StructField]] def nullable: Boolean } @@ -65,6 +69,7 @@ object StructTypeEncoder extends MediumPriorityImplicits { def pure[A](st: StructType, isNullable: Boolean = false): StructTypeEncoder[A] = new StructTypeEncoder[A] { def encode: StructType = st + def fields: Option[Seq[StructField]] = Some(st.fields.toSeq) def nullable: Boolean = isNullable } } @@ -80,7 +85,7 @@ sealed trait AnnotatedStructTypeEncoder[A] { } object AnnotatedStructTypeEncoder extends MediumPriorityImplicits { - type Encode = Seq[Metadata] => StructType + type Encode = (Seq[Metadata], Seq[Option[Flatten]]) => StructType def pure[A](enc: Encode): AnnotatedStructTypeEncoder[A] = new AnnotatedStructTypeEncoder[A] { @@ -89,29 +94,46 @@ object AnnotatedStructTypeEncoder extends MediumPriorityImplicits { } trait LowPriorityImplicits { - implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] = AnnotatedStructTypeEncoder.pure(_ => StructType(Nil)) + implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] = + AnnotatedStructTypeEncoder.pure((_, _) => StructType(Nil)) implicit def hconsEncoder[K <: Symbol, H, T <: HList]( implicit witness: Witness.Aux[K], hEncoder: Lazy[DataTypeEncoder[H]], tEncoder: AnnotatedStructTypeEncoder[T] - ): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { metadata => + ): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { (metadata, flatten) => val fieldName = witness.value.name - val head = hEncoder.value.encode - val nullable = hEncoder.value.nullable - val tail = tEncoder.encode(metadata.tail) - StructType(StructField(fieldName, head, nullable, metadata.head) +: tail.fields) + val dt = hEncoder.value.encode + val fields = flatten.head.flatMap(f => hEncoder.value.fields.map(flattenFields(_, dt, fieldName, f))).getOrElse( + Seq(StructField(fieldName, dt, hEncoder.value.nullable, metadata.head))) + val tail = tEncoder.encode(metadata.tail, flatten.tail) + StructType(fields ++ tail.fields) } - implicit def recordEncoder[A, H <: HList, HA <: HList]( + private def flattenFields(fields: Seq[StructField], dt: DataType, prefix: String, flatten: Flatten): Seq[StructField] = + (dt, flatten) match { + case (_: ArrayType, Flatten(times, _)) if times > 1 => + (0 until times).flatMap(i => fields.map(prefixStructField(_, s"$prefix.$i"))) + case (_: MapType, Flatten(_, keys)) if keys.nonEmpty => + keys.flatMap(k => fields.map(prefixStructField(_, s"$prefix.$k"))) + case (_, Flatten(_, _)) => fields.map(prefixStructField(_, prefix)) + } + + private def prefixStructField(f: StructField, prefix: String) = + f.copy(name = s"$prefix.${f.name}") + + implicit def recordEncoder[A, H <: HList, HA <: HList, HF <: HList]( implicit generic: LabelledGeneric.Aux[A, H], - annotations: Annotations.Aux[Meta, A, HA], + metaAnnotations: Annotations.Aux[Meta, A, HA], + flattenAnnotations: Annotations.Aux[Flatten, A, HF], hEncoder: Lazy[AnnotatedStructTypeEncoder[H]], - toList: ToList[HA, Option[Meta]] + metaToList: ToList[HA, Option[Meta]], + flattenToList: ToList[HF, Option[Flatten]] ): StructTypeEncoder[A] = { - val metadata = annotations().toList[Option[Meta]].map(extractMetadata) - StructTypeEncoder.pure(hEncoder.value.encode(metadata)) + val metadata = metaAnnotations().toList[Option[Meta]].map(extractMetadata) + val flatten = flattenAnnotations().toList[Option[Flatten]] + StructTypeEncoder.pure(hEncoder.value.encode(metadata, flatten)) } private val extractMetadata: Option[Meta] => Metadata = @@ -153,16 +175,16 @@ trait MediumPriorityImplicits extends LowPriorityImplicits { enc: DataTypeEncoder[A0], is: IsTraversableOnce[C[A0]] { type A = A0 } ): DataTypeEncoder[C[A0]] = - DataTypeEncoder.pure(ArrayType(enc.encode)) + DataTypeEncoder.pure(ArrayType(enc.encode), enc.fields) implicit def mapEncoder[K, V]( implicit kEnc: DataTypeEncoder[K], vEnc: DataTypeEncoder[V] ): DataTypeEncoder[Map[K, V]] = - DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode)) + DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode), vEnc.fields) implicit def optionEncoder[V]( implicit enc: DataTypeEncoder[V] ): DataTypeEncoder[Option[V]] = - DataTypeEncoder.pure(enc.encode, true) + DataTypeEncoder.pure(enc.encode, isNullable = true) } diff --git a/core/src/main/scala/ste/selector.scala b/core/src/main/scala/ste/selector.scala new file mode 100644 index 0000000..41cb975 --- /dev/null +++ b/core/src/main/scala/ste/selector.scala @@ -0,0 +1,220 @@ +/** + * Copyright (c) 2017-2018, Benjamin Fradet, and other contributors. + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package ste + +import org.apache.spark.sql.{ Column, DataFrame, Dataset, Encoder } +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import scala.annotation.tailrec +import scala.collection.generic.IsTraversableOnce +import scala.collection.breakOut +import shapeless._ +import shapeless.ops.hlist._ +import shapeless.syntax.std.tuple._ +import shapeless.labelled.FieldType + +case class Prefix(p: String) { + def addSuffix(s: Any) = Prefix(s"$p.$s") + def getParent = Prefix(p.split("\\.").dropRight(1).mkString(".")) + def getSuffix = p.split("\\.").last + def isParentOf(other: Prefix) = other.toString.startsWith(s"$p.") + def isChildrenOf(other: Prefix) = other.isParentOf(this) + def quotedString = s"`$p`" + override def toString = p +} + +@annotation.implicitNotFound(""" + Type ${A} does not have a DataTypeSelector defined in the library. + You need to define one yourself. + """) +sealed trait DataTypeSelector[A] { + import DataTypeSelector.Select + + val select: Select +} + +object DataTypeSelector { + type Prefixes = List[Prefix] + type Select = (DataFrame, Option[Prefixes]) => DataFrame + + def pure[A](s: Select): DataTypeSelector[A] = + new DataTypeSelector[A] { + val select: Select = s + } + + def identityDF[A]: DataTypeSelector[A] = + new DataTypeSelector[A] { + val select: Select = (df, _) => df + } +} + +@annotation.implicitNotFound(""" + Type ${A} does not have a StructTypeSelector defined in the library. + You need to define one yourself. + """) +sealed trait StructTypeSelector[A] extends DataTypeSelector[A] { + import DataTypeSelector.Select + + val select: Select +} + +object StructTypeSelector extends SelectorImplicits { + import DataTypeSelector.Select + + def apply[A](implicit s: StructTypeSelector[A]): StructTypeSelector[A] = s + + def pure[A](s: Select): StructTypeSelector[A] = + new StructTypeSelector[A] { + val select: Select = s + } +} + +@annotation.implicitNotFound(""" + Type ${A} does not have a AnnotatedStructTypeSelector defined in the library. + You need to define one yourself. + """) +sealed trait AnnotatedStructTypeSelector[A] { + import AnnotatedStructTypeSelector.Select + + val select: Select +} + +object AnnotatedStructTypeSelector extends SelectorImplicits { + import DataTypeSelector.Prefixes + + type Select = (DataFrame, Option[Prefixes], Seq[Option[Flatten]]) => DataFrame + + def pure[A](s: Select): AnnotatedStructTypeSelector[A] = + new AnnotatedStructTypeSelector[A] { + val select = s + } +} + +trait SelectorImplicits { + implicit val hnilSelector: AnnotatedStructTypeSelector[HNil] = + AnnotatedStructTypeSelector.pure((df, _, _) => df) + + implicit def hconsSelector[K <: Symbol, H, T <: HList]( + implicit + witness: Witness.Aux[K], + hSelector: Lazy[DataTypeSelector[H]], + tSelector: AnnotatedStructTypeSelector[T] + ): AnnotatedStructTypeSelector[FieldType[K, H] :: T] = AnnotatedStructTypeSelector.pure { (df, parentPrefixes, flatten) => + val fieldName = witness.value.name + val prefixes = parentPrefixes.map(_.map(_.addSuffix(fieldName))).getOrElse(List(Prefix(fieldName))) + val childPrefixes = getChildPrefixes(prefixes, flatten.head) + val dfHead = hSelector.value.select(df, Some(childPrefixes)) + val dfNested = flatten.head.map { fl => + val fields = dfHead.schema.fields.map(f => Prefix(f.name)).toList + val restCols = fields.filter(f => !childPrefixes.exists(_.isParentOf(f))).map(f => dfHead(f.quotedString)) + val structs = childPrefixes.map { p => + val cols = fields.filter(_.isChildrenOf(p)).map(f => dfHead(f.quotedString).as(f.getSuffix)) + struct(cols :_*).as(p.toString) + } + val dfStruct = dfHead.select((structs ++ restCols) :_*) + val nestedCols = getNestedColumns(childPrefixes, dfStruct, fl) + orderedSelect(dfStruct, nestedCols, fields) + }.getOrElse(dfHead) + tSelector.select(dfNested, parentPrefixes, flatten.tail) + } + + private def getChildPrefixes(prefixes: List[Prefix], flatten: Option[Flatten]): List[Prefix] = + flatten.map { + case Flatten(times, _) if times > 1 => (0 until times).flatMap(i => prefixes.map(_.addSuffix(i))).toList + case Flatten(_, keys) if keys.nonEmpty => keys.flatMap(k => prefixes.map(_.addSuffix(k))).toList + case Flatten(_, _) => prefixes + }.getOrElse(prefixes) + + private def getNestedColumns(prefixes: List[Prefix], df: DataFrame, flatten: Flatten): Map[Prefix, Column] = + prefixes.groupBy(_.getParent).map { case (prefix, groupedPrefixes) => + val colName = prefix.toString + val cols = groupedPrefixes.map(p => df(p.quotedString)) + flatten match { + case Flatten(times, _) if times > 1 => (prefix, array(cols :_*).as(colName)) + case Flatten(_, keys) if keys.nonEmpty => (prefix, map(interleave(keys.map(lit), cols) :_*).as(colName)) + case Flatten(_, _) => (groupedPrefixes.head, cols.head) + } + }(breakOut) + + private def orderedSelect(df: DataFrame, nestedCols: Map[Prefix, Column], fields: List[Prefix]): DataFrame = { + @tailrec + def loop(nestedCols: Map[Prefix, Column], fields: List[Prefix], cols: List[Column]): List[Column] = fields match { + case Nil => cols.reverse + case hd +: tail => nestedCols.find { case (p, _) => p.isParentOf(hd) } match { + case Some((p, c)) => loop(nestedCols - p, fields.dropWhile(_.isChildrenOf(p)), c +: cols) + case None => loop(nestedCols, tail, df(hd.quotedString) +: cols) + } + } + val cols = loop(nestedCols, fields, List[Column]()) + df.select(cols :_*) + } + + private def interleave[T](a: Seq[T], b: Seq[T]): Seq[T] = a.zip(b).flatMap(_.toList) + + implicit def dfSelector[A, H <: HList, HF <: HList]( + implicit + generic: LabelledGeneric.Aux[A, H], + flattenAnnotations: Annotations.Aux[Flatten, A, HF], + hSelector: Lazy[AnnotatedStructTypeSelector[H]], + flattenToList: ToList[HF, Option[Flatten]] + ): StructTypeSelector[A] = StructTypeSelector.pure { (df, prefixes) => + val flatten = flattenAnnotations().toList[Option[Flatten]] + hSelector.value.select(df, prefixes, flatten) + } + + implicit val binarySelector: DataTypeSelector[Array[Byte]] = DataTypeSelector.identityDF + implicit val booleanSelector: DataTypeSelector[Boolean] = DataTypeSelector.identityDF + implicit val byteSelector: DataTypeSelector[Byte] = DataTypeSelector.identityDF + implicit val dateSelector: DataTypeSelector[java.sql.Date] = DataTypeSelector.identityDF + implicit val decimalSelector: DataTypeSelector[BigDecimal] = DataTypeSelector.identityDF + implicit val doubleSelector: DataTypeSelector[Double] = DataTypeSelector.identityDF + implicit val floatSelector: DataTypeSelector[Float] = DataTypeSelector.identityDF + implicit val intSelector: DataTypeSelector[Int] = DataTypeSelector.identityDF + implicit val longSelector: DataTypeSelector[Long] = DataTypeSelector.identityDF + implicit val nullSelector: DataTypeSelector[Unit] = DataTypeSelector.identityDF + implicit val shortSelector: DataTypeSelector[Short] = DataTypeSelector.identityDF + implicit val stringSelector: DataTypeSelector[String] = DataTypeSelector.identityDF + implicit val timestampSelector: DataTypeSelector[java.sql.Timestamp] = DataTypeSelector.identityDF + implicit def optionSelector[T]: DataTypeSelector[Option[T]] = DataTypeSelector.identityDF + + implicit def traversableOnceSelector[A0, C[_]]( + implicit + s: DataTypeSelector[A0], + is: IsTraversableOnce[C[A0]] { type A = A0 } + ): DataTypeSelector[C[A0]] = DataTypeSelector.pure { (df, prefixes) => + s.select(df, prefixes) + } + + implicit def mapSelector[K, V]( + implicit s: DataTypeSelector[V] + ): DataTypeSelector[Map[K, V]] = DataTypeSelector.pure { (df, prefixes) => + s.select(df, prefixes) + } +} + +object DFUtils { + implicit class FlattenedDataFrame(df: DataFrame) { + def asNested[A : Encoder : StructTypeSelector]: Dataset[A] = selectNested.as[A] + + def selectNested[A](implicit s: StructTypeSelector[A]): DataFrame = s.select(df, None) + } +} diff --git a/core/src/test/scala/ste/StructTypeEncoderSpec.scala b/core/src/test/scala/ste/StructTypeEncoderSpec.scala index 90e904e..ac72784 100644 --- a/core/src/test/scala/ste/StructTypeEncoderSpec.scala +++ b/core/src/test/scala/ste/StructTypeEncoderSpec.scala @@ -115,9 +115,18 @@ class StructTypeEncoderSpec extends FlatSpec with Matchers { .build case class Foo(a: String, @Meta(metadata) b: Int) - StructTypeEncoder[Foo].encode shouldBe StructType( - StructField("a", StringType, false) :: - StructField("b", IntegerType, false, metadata) :: Nil + case class Bar(@Flatten(2) a: Seq[Foo], @Flatten(1, Seq("x", "y")) b: Map[String, Foo], @Flatten c: Foo) + StructTypeEncoder[Bar].encode shouldBe StructType( + StructField("a.0.a", StringType, false) :: + StructField("a.0.b", IntegerType, false, metadata) :: + StructField("a.1.a", StringType, false) :: + StructField("a.1.b", IntegerType, false, metadata) :: + StructField("b.x.a", StringType, false) :: + StructField("b.x.b", IntegerType, false, metadata) :: + StructField("b.y.a", StringType, false) :: + StructField("b.y.b", IntegerType, false, metadata) :: + StructField("c.a", StringType, false) :: + StructField("c.b", IntegerType, false, metadata) :: Nil ) } } diff --git a/core/src/test/scala/ste/StructTypeSelectorSpec.scala b/core/src/test/scala/ste/StructTypeSelectorSpec.scala new file mode 100644 index 0000000..6f8c4bb --- /dev/null +++ b/core/src/test/scala/ste/StructTypeSelectorSpec.scala @@ -0,0 +1,92 @@ +/** + * Copyright (c) 2017-2018, Benjamin Fradet, and other contributors. + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package ste + +import org.apache.spark.sql.types._ +import org.apache.spark.sql.SparkSession +import org.scalatest.{ FlatSpec, Matchers } +import ste._ +import StructTypeEncoder._ +import StructTypeSelector._ +import DFUtils._ + +object StructSelectorSpec { + case class Foo(a: Int, b: String) + case class Bar(@Flatten(1, Seq("asd", "qwe")) foo: Map[String, Foo], c: Int) + case class Baz(@Flatten(2) bar: Seq[Bar], e: Int) + case class Asd(@Flatten foo: Foo, x: Int) +} + +class StructSelectorSpec extends FlatSpec with Matchers { + import StructSelectorSpec._ + val spark = SparkSession.builder().master("local").getOrCreate() + + "selectNested" should "return the nested DataFrame" in { + import spark.implicits._ + val values = List((1, "a", 2, "b", 3), (4, "c", 5, "d", 6)) + val df = values.toDF(StructTypeEncoder[Bar].encode.fields.map(_.name) :_*) + val result = df.selectNested[Bar] + val expected = Array( + Bar(Map("asd" -> Foo(1, "a"), "qwe" -> Foo(2, "b")), 3), + Bar(Map("asd" -> Foo(4, "c"), "qwe" -> Foo(5, "d")), 6) + ) + result.as[Bar].collect shouldEqual expected + } + + it should "deal with flattened struct" in { + import spark.implicits._ + val values = List((1, "a", 2), (3, "b", 4)) + val df = values.toDF(StructTypeEncoder[Asd].encode.fields.map(_.name) :_*) + val result = df.asNested[Asd].collect + val expected = Array( + Asd(Foo(1, "a"), 2), Asd(Foo(3, "b"), 4) + ) + result shouldEqual expected + } + + it should "deal with deep nested structures" in { + import spark.implicits._ + val values = List( + (1, "a", 2, "b", 3, 4, "c", 5, "d", 6, 7), + (10, "aa", 20, "bb", 30, 40, "cc", 50, "dd", 60, 70) + ) + val df = values.toDF(StructTypeEncoder[Baz].encode.fields.map(_.name) :_*) + val result = df.asNested[Baz].collect + val expected = Array( + Baz( + Seq( + Bar(Map("asd" -> Foo(1, "a"), "qwe" -> Foo(2, "b")), 3), + Bar(Map("asd" -> Foo(4, "c"), "qwe" -> Foo(5, "d")), 6) + ), + 7 + ), + Baz( + Seq( + Bar(Map("asd" -> Foo(10, "aa"), "qwe" -> Foo(20, "bb")), 30), + Bar(Map("asd" -> Foo(40, "cc"), "qwe" -> Foo(50, "dd")), 60) + ), + 70 + ) + ) + result shouldEqual expected + } +}