Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ lazy val compilerOptions = Seq(
)

lazy val shapelessVersion = "2.3.2"
lazy val sparkVersion = "2.1.0"
lazy val sparkVersion = "2.1.2"
lazy val scalatestVersion = "3.0.1"

lazy val baseSettings = Seq(
Expand All @@ -28,6 +28,7 @@ lazy val baseSettings = Seq(
) ++ Seq(
"org.scalatest" %% "scalatest" % scalatestVersion
).map(_ % "test"),
parallelExecution in Test := false,
scalacOptions ++= compilerOptions
)

Expand Down
56 changes: 39 additions & 17 deletions core/src/main/scala/ste/encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,26 @@ import scala.annotation.StaticAnnotation

import scala.collection.generic.IsTraversableOnce

final class Meta(val metadata: Metadata) extends StaticAnnotation
final case class Meta(metadata: Metadata) extends StaticAnnotation
/**
 * Field annotation asking for a field's nested fields to be lifted into the parent schema
 * under dot-separated names.
 *
 * @param times when greater than 1 on an array-typed field, the element fields are repeated
 *              once per index in `0 until times` (names look like `field.0.name`)
 * @param keys  when non-empty on a map-typed field, the value fields are repeated once per
 *              key (names look like `field.key.name`)
 */
final case class Flatten(
  times: Int = 1,
  keys: Seq[String] = Seq.empty
) extends StaticAnnotation

/**
 * Type class giving the Spark SQL representation of a Scala type `A`.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a DataTypeEncoder defined in the library.
You need to define one yourself.
""")
sealed trait DataTypeEncoder[A] {
  /** The Spark `DataType` corresponding to `A`. */
  def encode: DataType

  /** Nullability used when a `StructField` is built for a field of type `A`. */
  def nullable: Boolean

  /**
   * The struct fields reachable from `A`, if any; consulted when a field carries a
   * `Flatten` annotation so its nested fields can be lifted into the parent schema.
   */
  def fields: Option[Seq[StructField]]
}

object DataTypeEncoder {
def apply[A](implicit enc: DataTypeEncoder[A]): DataTypeEncoder[A] = enc

def pure[A](dt: DataType, isNullable: Boolean = false): DataTypeEncoder[A] =
def pure[A](dt: DataType, f: Option[Seq[StructField]] = None, isNullable: Boolean = false): DataTypeEncoder[A] =
new DataTypeEncoder[A] {
def encode: DataType = dt
def fields: Option[Seq[StructField]] = f
def nullable: Boolean = isNullable
}
}
Expand All @@ -56,6 +59,7 @@ object DataTypeEncoder {
""")
sealed trait StructTypeEncoder[A] extends DataTypeEncoder[A] {
def encode: StructType
def fields: Option[Seq[StructField]]
def nullable: Boolean
}

Expand All @@ -65,6 +69,7 @@ object StructTypeEncoder extends MediumPriorityImplicits {
/**
 * Wraps an already-built `StructType` into a `StructTypeEncoder`.
 *
 * @param st         the struct type returned by `encode`; its fields are also exposed through `fields`
 * @param isNullable value reported by `nullable` (defaults to `false`)
 */
def pure[A](st: StructType, isNullable: Boolean = false): StructTypeEncoder[A] =
  new StructTypeEncoder[A] {
    def nullable: Boolean = isNullable
    def fields: Option[Seq[StructField]] = Some(st.fields.toSeq)
    def encode: StructType = st
  }
}
Expand All @@ -80,7 +85,7 @@ sealed trait AnnotatedStructTypeEncoder[A] {
}

object AnnotatedStructTypeEncoder extends MediumPriorityImplicits {
type Encode = Seq[Metadata] => StructType
type Encode = (Seq[Metadata], Seq[Option[Flatten]]) => StructType

def pure[A](enc: Encode): AnnotatedStructTypeEncoder[A] =
new AnnotatedStructTypeEncoder[A] {
Expand All @@ -89,29 +94,46 @@ object AnnotatedStructTypeEncoder extends MediumPriorityImplicits {
}

trait LowPriorityImplicits {
implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] = AnnotatedStructTypeEncoder.pure(_ => StructType(Nil))
implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] =
AnnotatedStructTypeEncoder.pure((_, _) => StructType(Nil))
implicit def hconsEncoder[K <: Symbol, H, T <: HList](
implicit
witness: Witness.Aux[K],
hEncoder: Lazy[DataTypeEncoder[H]],
tEncoder: AnnotatedStructTypeEncoder[T]
): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { metadata =>
): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { (metadata, flatten) =>
val fieldName = witness.value.name
val head = hEncoder.value.encode
val nullable = hEncoder.value.nullable
val tail = tEncoder.encode(metadata.tail)
StructType(StructField(fieldName, head, nullable, metadata.head) +: tail.fields)
val dt = hEncoder.value.encode
val fields = flatten.head.flatMap(f => hEncoder.value.fields.map(flattenFields(_, dt, fieldName, f))).getOrElse(
Seq(StructField(fieldName, dt, hEncoder.value.nullable, metadata.head)))
val tail = tEncoder.encode(metadata.tail, flatten.tail)
StructType(fields ++ tail.fields)
}

implicit def recordEncoder[A, H <: HList, HA <: HList](
/**
 * Produces the flattened struct fields for one annotated field.
 *
 *  - array type with `times > 1`: the fields are repeated for every index, prefixed `prefix.i`
 *  - map type with non-empty `keys`: the fields are repeated for every key, prefixed `prefix.k`
 *  - anything else: the fields are prefixed once with `prefix`
 *
 * NOTE(review): the selector side (`getChildPrefixes`) applies the `times`/`keys` rules without
 * looking at the data type, so the two may disagree for e.g. `Flatten(2)` on a non-array field
 * -- confirm this is intended.
 */
private def flattenFields(fields: Seq[StructField], dt: DataType, prefix: String, flatten: Flatten): Seq[StructField] = {
  // Re-labels every field under the given dotted path.
  def prefixed(path: String): Seq[StructField] = fields.map(prefixStructField(_, path))

  (dt, flatten) match {
    case (_: ArrayType, Flatten(times, _)) if times > 1 =>
      (0 until times).flatMap(index => prefixed(s"$prefix.$index"))
    case (_: MapType, Flatten(_, keys)) if keys.nonEmpty =>
      keys.flatMap(key => prefixed(s"$prefix.$key"))
    case (_, Flatten(_, _)) =>
      prefixed(prefix)
  }
}

/** Returns a copy of `f` renamed to `prefix.name`; every other attribute of the field is kept. */
private def prefixStructField(f: StructField, prefix: String): StructField = {
  val qualifiedName = s"$prefix.${f.name}"
  f.copy(name = qualifiedName)
}

implicit def recordEncoder[A, H <: HList, HA <: HList, HF <: HList](
implicit
generic: LabelledGeneric.Aux[A, H],
annotations: Annotations.Aux[Meta, A, HA],
metaAnnotations: Annotations.Aux[Meta, A, HA],
flattenAnnotations: Annotations.Aux[Flatten, A, HF],
hEncoder: Lazy[AnnotatedStructTypeEncoder[H]],
toList: ToList[HA, Option[Meta]]
metaToList: ToList[HA, Option[Meta]],
flattenToList: ToList[HF, Option[Flatten]]
): StructTypeEncoder[A] = {
val metadata = annotations().toList[Option[Meta]].map(extractMetadata)
StructTypeEncoder.pure(hEncoder.value.encode(metadata))
val metadata = metaAnnotations().toList[Option[Meta]].map(extractMetadata)
val flatten = flattenAnnotations().toList[Option[Flatten]]
StructTypeEncoder.pure(hEncoder.value.encode(metadata, flatten))
}

private val extractMetadata: Option[Meta] => Metadata =
Expand Down Expand Up @@ -153,16 +175,16 @@ trait MediumPriorityImplicits extends LowPriorityImplicits {
enc: DataTypeEncoder[A0],
is: IsTraversableOnce[C[A0]] { type A = A0 }
): DataTypeEncoder[C[A0]] =
DataTypeEncoder.pure(ArrayType(enc.encode))
DataTypeEncoder.pure(ArrayType(enc.encode), enc.fields)
implicit def mapEncoder[K, V](
implicit
kEnc: DataTypeEncoder[K],
vEnc: DataTypeEncoder[V]
): DataTypeEncoder[Map[K, V]] =
DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode))
DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode), vEnc.fields)
implicit def optionEncoder[V](
implicit
enc: DataTypeEncoder[V]
): DataTypeEncoder[Option[V]] =
DataTypeEncoder.pure(enc.encode, true)
DataTypeEncoder.pure(enc.encode, isNullable = true)
}
220 changes: 220 additions & 0 deletions core/src/main/scala/ste/selector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/**
* Copyright (c) 2017-2018, Benjamin Fradet, and other contributors.
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package ste
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we add the license header in this file and the associated spec?


import org.apache.spark.sql.{ Column, DataFrame, Dataset, Encoder }
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.annotation.tailrec
import scala.collection.generic.IsTraversableOnce
import scala.collection.breakOut
import shapeless._
import shapeless.ops.hlist._
import shapeless.syntax.std.tuple._
import shapeless.labelled.FieldType

/**
 * A dot-separated column path such as `a.0.b`, used to relate flattened column names to the
 * nested field they belong to.
 *
 * @param p the raw dotted path
 */
case class Prefix(p: String) {
  /** Appends `s` (rendered through string interpolation) as a new last segment. */
  def addSuffix(s: Any): Prefix = Prefix(s"$p.$s")

  /** The path without its last segment; a single-segment path yields the empty prefix. */
  def getParent: Prefix = Prefix(segments.dropRight(1).mkString("."))

  /**
   * The last segment of the path.
   *
   * `String.split` drops trailing empty segments, so a path made only of dots splits into an
   * empty array; the empty string is returned in that case instead of throwing.
   */
  def getSuffix: String = segments.lastOption.getOrElse("")

  /** True when `other` starts with this path followed by a dot (so `a` is not a parent of `ab.c`). */
  def isParentOf(other: Prefix): Boolean = other.toString.startsWith(s"$p.")

  /** True when this path lies below `other`; the mirror image of [[isParentOf]]. */
  def isChildrenOf(other: Prefix): Boolean = other.isParentOf(this)

  /** The path wrapped in backticks so the dots are read as part of one column name. */
  def quotedString: String = s"`$p`"

  override def toString: String = p

  // Path segments as produced by splitting on literal dots.
  private def segments: Array[String] = p.split("\\.")
}

/**
 * Type class carrying, for a type `A`, a function that rewrites a `DataFrame` given the
 * optional column prefixes under which the columns for `A` are found.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a DataTypeSelector defined in the library.
You need to define one yourself.
""")
sealed trait DataTypeSelector[A] {
  /** The rewrite to apply; see `DataTypeSelector.Select` for its shape. */
  val select: DataTypeSelector.Select
}

/** Type aliases and constructors for [[DataTypeSelector]] instances. */
object DataTypeSelector {
  /** The column paths a selector is asked to work under. */
  type Prefixes = List[Prefix]

  /** A `DataFrame` rewrite parameterised by optional prefixes (`None` is passed at the top level). */
  type Select = (DataFrame, Option[Prefixes]) => DataFrame

  /** Builds a selector from the given function. */
  def pure[A](s: Select): DataTypeSelector[A] =
    new DataTypeSelector[A] {
      val select: Select = s
    }

  /** A selector that returns the `DataFrame` unchanged and ignores the prefixes. */
  def identityDF[A]: DataTypeSelector[A] =
    pure[A]((df, _) => df)
}

/**
 * A [[DataTypeSelector]] for struct-like types; this is the instance looked up for a whole
 * record by `DFUtils`.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a StructTypeSelector defined in the library.
You need to define one yourself.
""")
sealed trait StructTypeSelector[A] extends DataTypeSelector[A] {
  /** Same contract as the inherited member; redeclared here without change. */
  val select: DataTypeSelector.Select
}

/** Summoner and constructor for [[StructTypeSelector]]; inherits the implicit instances of `SelectorImplicits`. */
object StructTypeSelector extends SelectorImplicits {
  /** Summons the implicit instance for `A`. */
  def apply[A](implicit s: StructTypeSelector[A]): StructTypeSelector[A] = s

  /** Builds a struct selector from the given function. */
  def pure[A](s: DataTypeSelector.Select): StructTypeSelector[A] =
    new StructTypeSelector[A] {
      val select: DataTypeSelector.Select = s
    }
}

/**
 * Selector for the record representation of a case class; besides the `DataFrame` and the
 * prefixes, its function also receives the per-field `Flatten` annotations.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a AnnotatedStructTypeSelector defined in the library.
You need to define one yourself.
""")
sealed trait AnnotatedStructTypeSelector[A] {
  /** The rewrite to apply; see `AnnotatedStructTypeSelector.Select` for its shape. */
  val select: AnnotatedStructTypeSelector.Select
}

/** Type alias and constructor for [[AnnotatedStructTypeSelector]]; inherits the implicit instances of `SelectorImplicits`. */
object AnnotatedStructTypeSelector extends SelectorImplicits {
  /**
   * A `DataFrame` rewrite taking the optional parent prefixes and one optional `Flatten`
   * annotation per remaining field.
   */
  type Select = (DataFrame, Option[DataTypeSelector.Prefixes], Seq[Option[Flatten]]) => DataFrame

  /** Builds an annotated selector from the given function. */
  def pure[A](s: Select): AnnotatedStructTypeSelector[A] =
    new AnnotatedStructTypeSelector[A] {
      val select: Select = s
    }
}

trait SelectorImplicits {
/** Base case of the record derivation: with no fields left, the `DataFrame` is returned as is. */
implicit val hnilSelector: AnnotatedStructTypeSelector[HNil] =
  AnnotatedStructTypeSelector.pure { (unchanged, _, _) => unchanged }

/**
 * Inductive case of the record derivation: handles the head field, then hands the result to
 * the selector of the remaining fields.
 *
 * For the head field it
 *  1. computes the prefixes of the field (the parent prefixes extended with the field name, or
 *     the bare field name when there is no parent) and, from its annotation, the child prefixes;
 *  2. runs the head type's selector under those child prefixes;
 *  3. if the field carries a `Flatten` annotation, packs the columns below each child prefix
 *     into a struct, combines the structs through `getNestedColumns`, and re-selects everything
 *     with `orderedSelect`.
 *
 * `flatten` must hold one entry per remaining field, head first.
 */
implicit def hconsSelector[K <: Symbol, H, T <: HList](
  implicit
  witness: Witness.Aux[K],
  hSelector: Lazy[DataTypeSelector[H]],
  tSelector: AnnotatedStructTypeSelector[T]
): AnnotatedStructTypeSelector[FieldType[K, H] :: T] =
  AnnotatedStructTypeSelector.pure { (df, parentPrefixes, flatten) =>
    val name = witness.value.name
    val ownPrefixes = parentPrefixes.fold(List(Prefix(name)))(_.map(_.addSuffix(name)))
    val headFlatten = flatten.head
    val childPrefixes = getChildPrefixes(ownPrefixes, headFlatten)
    val selected = hSelector.value.select(df, Some(childPrefixes))

    val renested = headFlatten.fold(selected) { annotation =>
      val flatNames = selected.schema.fields.map(f => Prefix(f.name)).toList
      // Columns that do not live under any child prefix are carried over untouched.
      val untouched = flatNames
        .filterNot(field => childPrefixes.exists(_.isParentOf(field)))
        .map(field => selected(field.quotedString))
      // One struct per child prefix, holding its columns under their last path segment.
      val structs = childPrefixes.map { child =>
        val members = flatNames
          .filter(_.isChildrenOf(child))
          .map(member => selected(member.quotedString).as(member.getSuffix))
        struct(members: _*).as(child.toString)
      }
      val withStructs = selected.select((structs ++ untouched): _*)
      val nestedCols = getNestedColumns(childPrefixes, withStructs, annotation)
      orderedSelect(withStructs, nestedCols, flatNames)
    }

    tSelector.select(renested, parentPrefixes, flatten.tail)
  }

/**
 * Expands the prefixes of a field according to its optional `Flatten` annotation:
 * one prefix per index in `0 until times` when `times > 1`, else one prefix per key when
 * `keys` is non-empty, else the prefixes unchanged (also when there is no annotation).
 */
private def getChildPrefixes(prefixes: List[Prefix], flatten: Option[Flatten]): List[Prefix] =
  flatten match {
    case Some(Flatten(times, _)) if times > 1 =>
      (for {
        index <- 0 until times
        prefix <- prefixes
      } yield prefix.addSuffix(index)).toList
    case Some(Flatten(_, keys)) if keys.nonEmpty =>
      (for {
        key <- keys
        prefix <- prefixes
      } yield prefix.addSuffix(key)).toList
    case Some(Flatten(_, _)) => prefixes
    case None => prefixes
  }

/**
 * Groups the child prefixes by parent and builds one combined column per group:
 * an array column when `times > 1`, a map column keyed by `keys` when `keys` is non-empty,
 * and otherwise the first child column itself, keyed by that child prefix.
 */
private def getNestedColumns(prefixes: List[Prefix], df: DataFrame, flatten: Flatten): Map[Prefix, Column] =
  prefixes.groupBy(_.getParent).map { case (parent, children) =>
    val parentName = parent.toString
    val childCols = children.map(child => df(child.quotedString))
    val entry: (Prefix, Column) = flatten match {
      case Flatten(times, _) if times > 1 =>
        parent -> array(childCols: _*).as(parentName)
      case Flatten(_, keys) if keys.nonEmpty =>
        // Spark's `map` takes alternating key and value columns.
        parent -> map(interleave(keys.map(lit), childCols): _*).as(parentName)
      case Flatten(_, _) =>
        children.head -> childCols.head
    }
    entry
  }(breakOut)

/**
 * Selects columns from `df` following the order of `fields`, substituting a nested column
 * wherever a field falls under one of the prefixes in `nestedCols`.
 *
 * @param df         the frame the columns are selected from
 * @param nestedCols combined columns keyed by the prefix whose children they replace
 * @param fields     the flat column names that define the output order
 */
private def orderedSelect(df: DataFrame, nestedCols: Map[Prefix, Column], fields: List[Prefix]): DataFrame = {
// Accumulates columns in reverse and flips them once `fields` is exhausted.
@tailrec
def loop(nestedCols: Map[Prefix, Column], fields: List[Prefix], cols: List[Column]): List[Column] = fields match {
case Nil => cols.reverse
// `+:` is used for the list pattern because `import shapeless._` redefines `::` in this file.
case hd +: tail => nestedCols.find { case (p, _) => p.isParentOf(hd) } match {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since you have lists you can do hd :: tail, same thing below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea, unfortunately shapeless overrides the :: definition

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, too bad :(

// A nested column covers this field: emit it once, retire its prefix, and skip the
// leading run of fields that sit under that prefix.
case Some((p, c)) => loop(nestedCols - p, fields.dropWhile(_.isChildrenOf(p)), c +: cols)
// No nested column applies: keep the field's own column from `df`.
case None => loop(nestedCols, tail, df(hd.quotedString) +: cols)
}
}
val cols = loop(nestedCols, fields, List[Column]())
df.select(cols :_*)
}

/**
 * Alternates the elements of `a` and `b` (`a0, b0, a1, b1, ...`); since the sequences are
 * zipped, elements beyond the length of the shorter one are dropped.
 */
private def interleave[T](a: Seq[T], b: Seq[T]): Seq[T] = {
  val pairs = a.zip(b)
  // `toList` on a pair comes from the shapeless tuple syntax imported by this file.
  pairs.flatMap(pair => pair.toList)
}

/**
 * Derives the [[StructTypeSelector]] of a case class `A` from the selector of its record
 * representation `H`, feeding it the `Flatten` annotations found on the fields of `A`.
 */
implicit def dfSelector[A, H <: HList, HF <: HList](
  implicit
  generic: LabelledGeneric.Aux[A, H],
  flattenAnnotations: Annotations.Aux[Flatten, A, HF],
  hSelector: Lazy[AnnotatedStructTypeSelector[H]],
  flattenToList: ToList[HF, Option[Flatten]]
): StructTypeSelector[A] =
  StructTypeSelector.pure { (df, prefixes) =>
    // One entry per field of `A`, in declaration order; `None` for fields without annotation.
    val perFieldFlatten = flattenAnnotations().toList[Option[Flatten]]
    hSelector.value.select(df, prefixes, perFieldFlatten)
  }

// Leaf types: their columns need no restructuring, so every selector is the identity.

// Binary, boolean and null
implicit val binarySelector: DataTypeSelector[Array[Byte]] = DataTypeSelector.identityDF[Array[Byte]]
implicit val booleanSelector: DataTypeSelector[Boolean] = DataTypeSelector.identityDF[Boolean]
implicit val nullSelector: DataTypeSelector[Unit] = DataTypeSelector.identityDF[Unit]

// Numeric
implicit val byteSelector: DataTypeSelector[Byte] = DataTypeSelector.identityDF[Byte]
implicit val shortSelector: DataTypeSelector[Short] = DataTypeSelector.identityDF[Short]
implicit val intSelector: DataTypeSelector[Int] = DataTypeSelector.identityDF[Int]
implicit val longSelector: DataTypeSelector[Long] = DataTypeSelector.identityDF[Long]
implicit val floatSelector: DataTypeSelector[Float] = DataTypeSelector.identityDF[Float]
implicit val doubleSelector: DataTypeSelector[Double] = DataTypeSelector.identityDF[Double]
implicit val decimalSelector: DataTypeSelector[BigDecimal] = DataTypeSelector.identityDF[BigDecimal]

// Text and time
implicit val stringSelector: DataTypeSelector[String] = DataTypeSelector.identityDF[String]
implicit val dateSelector: DataTypeSelector[java.sql.Date] = DataTypeSelector.identityDF[java.sql.Date]
implicit val timestampSelector: DataTypeSelector[java.sql.Timestamp] =
  DataTypeSelector.identityDF[java.sql.Timestamp]

// NOTE(review): the option selector requires no selector for `T`, so nothing is applied to the
// wrapped type -- confirm that is intended for options of annotated case classes.
implicit def optionSelector[T]: DataTypeSelector[Option[T]] = DataTypeSelector.identityDF[Option[T]]

/** Selector for a collection `C[A0]`: applies the element selector with the same arguments. */
implicit def traversableOnceSelector[A0, C[_]](
  implicit
  s: DataTypeSelector[A0],
  is: IsTraversableOnce[C[A0]] { type A = A0 }
): DataTypeSelector[C[A0]] =
  DataTypeSelector.pure((frame, prefixes) => s.select(frame, prefixes))

/** Selector for `Map[K, V]`: applies the value selector with the same arguments; keys need none. */
implicit def mapSelector[K, V](
  implicit s: DataTypeSelector[V]
): DataTypeSelector[Map[K, V]] =
  DataTypeSelector.pure((frame, prefixes) => s.select(frame, prefixes))
}

/** Syntax for re-nesting a `DataFrame` whose columns follow the flattened naming scheme. */
object DFUtils {
  implicit class FlattenedDataFrame(df: DataFrame) {
    /** Applies the selector of `A` to the wrapped frame, starting without any parent prefix. */
    def selectNested[A](implicit s: StructTypeSelector[A]): DataFrame = s.select(df, None)

    /** Like `selectNested`, then converts the result to a typed `Dataset[A]`. */
    def asNested[A : Encoder : StructTypeSelector]: Dataset[A] = selectNested[A].as[A]
  }
}
15 changes: 12 additions & 3 deletions core/src/test/scala/ste/StructTypeEncoderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,18 @@ class StructTypeEncoderSpec extends FlatSpec with Matchers {
.build

case class Foo(a: String, @Meta(metadata) b: Int)
StructTypeEncoder[Foo].encode shouldBe StructType(
StructField("a", StringType, false) ::
StructField("b", IntegerType, false, metadata) :: Nil
case class Bar(@Flatten(2) a: Seq[Foo], @Flatten(1, Seq("x", "y")) b: Map[String, Foo], @Flatten c: Foo)
StructTypeEncoder[Bar].encode shouldBe StructType(
StructField("a.0.a", StringType, false) ::
StructField("a.0.b", IntegerType, false, metadata) ::
StructField("a.1.a", StringType, false) ::
StructField("a.1.b", IntegerType, false, metadata) ::
StructField("b.x.a", StringType, false) ::
StructField("b.x.b", IntegerType, false, metadata) ::
StructField("b.y.a", StringType, false) ::
StructField("b.y.b", IntegerType, false, metadata) ::
StructField("c.a", StringType, false) ::
StructField("c.b", IntegerType, false, metadata) :: Nil
)
}
}
Loading