Skip to content

Commit

Permalink
Merge pull request #1131 from twitter/ianoc/caseClassTupleConverters
Browse files Browse the repository at this point in the history
Ianoc/case class tuple converters
  • Loading branch information
johnynek committed Dec 16, 2014
2 parents c81d9e1 + af7cb52 commit e9b4f44
Show file tree
Hide file tree
Showing 6 changed files with 364 additions and 1 deletion.
13 changes: 12 additions & 1 deletion project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object ScaldingBuild extends Build {
// Version strings for third-party dependencies; module definitions refer to these by name
// (for example, the macros module uses bijectionVersion for bijection-macros).
val scalaCheckVersion = "1.11.5"
val hadoopVersion = "1.2.1"
val algebirdVersion = "0.8.2"
// Only the post-change value is kept: defining bijectionVersion twice in one scope does not compile.
val bijectionVersion = "0.7.1"
val chillVersion = "0.5.1"
val slf4jVersion = "1.6.6"
val parquetVersion = "1.6.0rc4"
Expand Down Expand Up @@ -200,6 +200,7 @@ object ScaldingBuild extends Build {
scaldingJson,
scaldingJdbc,
scaldingHadoopTest,
scaldingMacros,
maple
)

Expand Down Expand Up @@ -401,6 +402,16 @@ object ScaldingBuild extends Build {
}
).dependsOn(scaldingCore)

// Build definition for the "macros" module. It depends on the Scala library and reflection
// artifacts matching the build's scalaVersion, on bijection-macros, and (only when isScala210x
// says so) on the separate quasiquotes artifact. The macro paradise compiler plugin is enabled.
lazy val scaldingMacros = module("macros").settings(
  libraryDependencies <++= (scalaVersion) { scalaVer =>
    val reflectionDeps = Seq(
      "org.scala-lang" % "scala-library" % scalaVer,
      "org.scala-lang" % "scala-reflect" % scalaVer,
      "com.twitter" %% "bijection-macros" % bijectionVersion
    )
    val quasiquoteDeps =
      if (isScala210x(scalaVer)) Seq("org.scalamacros" %% "quasiquotes" % "2.0.1")
      else Seq()
    reflectionDeps ++ quasiquoteDeps
  },
  addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)
).dependsOn(scaldingCore, scaldingHadoopTest)

// This one uses a different naming convention
lazy val maple = Project(
id = "maple",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.twitter.scalding.macros

import scala.language.experimental.macros

import com.twitter.scalding._
import com.twitter.scalding.macros.impl.MacroImpl
import com.twitter.bijection.macros.IsCaseClass

/**
 * Implicit entry points for the case class macros. Importing a member of this object puts a
 * macro-materialized instance into implicit scope for any T that has IsCaseClass evidence.
 */
object MacroImplicits {
/**
 * Materializes a TupleSetter[T] by expanding MacroImpl.caseClassTupleSetterImpl.
 * Requires implicit IsCaseClass[T] evidence (the context bound).
 */
implicit def materializeCaseClassTupleSetter[T: IsCaseClass]: TupleSetter[T] = macro MacroImpl.caseClassTupleSetterImpl[T]
/**
 * Materializes a TupleConverter[T] by expanding MacroImpl.caseClassTupleConverterImpl.
 * Requires implicit IsCaseClass[T] evidence (the context bound).
 */
implicit def materializeCaseClassTupleConverter[T: IsCaseClass]: TupleConverter[T] = macro MacroImpl.caseClassTupleConverterImpl[T]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.twitter.scalding.macros

import scala.language.experimental.macros

import com.twitter.scalding._
import com.twitter.scalding.macros.impl.MacroImpl
import com.twitter.bijection.macros.IsCaseClass

/**
 * Explicit (non-implicit) entry points for the case class macros; call these to request a
 * macro-generated instance directly.
 */
object Macros {
// These only work for simple field types inside the case class.
// A field may itself be a case class (nesting is allowed); every other field must be one of:
// Int, Boolean, String, Long, Short, Float, Double. Any other field type is rejected.
// Both methods require implicit IsCaseClass[T] evidence (the context bound).
def caseClassTupleSetter[T: IsCaseClass]: TupleSetter[T] = macro MacroImpl.caseClassTupleSetterImpl[T]
def caseClassTupleConverter[T: IsCaseClass]: TupleConverter[T] = macro MacroImpl.caseClassTupleConverterImpl[T]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package com.twitter.scalding.macros.impl

import scala.collection.mutable.{ Map => MMap }
import scala.language.experimental.macros
import scala.reflect.macros.Context
import scala.reflect.runtime.universe._
import scala.util.{ Try => BasicTry }

import cascading.tuple.{ Tuple, TupleEntry }

import com.twitter.scalding._
import com.twitter.bijection.macros.{ IsCaseClass, MacroGenerated }
import com.twitter.bijection.macros.impl.IsCaseClassImpl
/**
 * This class contains the core macro implementations. This is in a separate module to allow it to be in
 * a separate compilation unit, which makes it easier to provide helper methods interfacing with macros.
 *
 * Both generators walk the case accessors of T. A field whose type is String, Boolean, Short, Int,
 * Long, Float or Double occupies one tuple position; a field whose type is itself a case class is
 * expanded recursively into its own fields. Any other field type aborts macro expansion.
 */
object MacroImpl {
  /** Macro entry point producing a TupleSetter[T] without asking for IsCaseClass evidence. */
  def caseClassTupleSetterNoProof[T]: TupleSetter[T] = macro caseClassTupleSetterNoProofImpl[T]

  /**
   * Implementation for the setter macros that carry IsCaseClass evidence. The `proof` expression is
   * not inspected; it only has to exist for the macro to be applicable. Delegates to
   * caseClassTupleSetterNoProofImpl.
   */
  def caseClassTupleSetterImpl[T](c: Context)(proof: c.Expr[IsCaseClass[T]])(implicit T: c.WeakTypeTag[T]): c.Expr[TupleSetter[T]] =
    caseClassTupleSetterNoProofImpl(c)(T)

  /**
   * Generates an anonymous TupleSetter[T] (also marked MacroGenerated) whose `apply` allocates a
   * cascading Tuple of the flattened size and writes every leaf field into it, and whose `arity`
   * is that same size.
   */
  def caseClassTupleSetterNoProofImpl[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[TupleSetter[T]] = {
    import c.universe._

    // For each case accessor of outerTpe, yields functions from a tuple index to the statement
    // storing that field. pTree is the expression that evaluates to the value being taken apart,
    // so nested case classes extend it with one more accessor selection per level.
    def expandMethod(outerTpe: Type, pTree: Tree): Iterable[Int => Tree] = {
      outerTpe
        .declarations
        .collect { case m: MethodSymbol if m.isCaseAccessor => m }
        .flatMap { accessorMethod =>
          accessorMethod.returnType match {
            case tpe if tpe =:= typeOf[String] => List((idx: Int) => q"""tup.setString(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Boolean] => List((idx: Int) => q"""tup.setBoolean(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Short] => List((idx: Int) => q"""tup.setShort(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Int] => List((idx: Int) => q"""tup.setInteger(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Long] => List((idx: Int) => q"""tup.setLong(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Float] => List((idx: Int) => q"""tup.setFloat(${idx}, $pTree.$accessorMethod)""")
            case tpe if tpe =:= typeOf[Double] => List((idx: Int) => q"""tup.setDouble(${idx}, $pTree.$accessorMethod)""")
            case tpe if IsCaseClassImpl.isCaseClassType(c)(tpe) =>
              // Nested case class: its fields are spliced in place, in sequence.
              expandMethod(tpe, q"""$pTree.$accessorMethod""")
            case _ => c.abort(c.enclosingPosition, s"Case class ${T} is not pure primitives or nested case classes")
          }
        }
    }

    // Tuple positions are handed out only after flattening, so they are contiguous from 0.
    val set =
      expandMethod(T.tpe, q"t")
        .zipWithIndex
        .map {
          case (treeGenerator, idx) =>
            treeGenerator(idx)
        }

    val res = q"""
    new _root_.com.twitter.scalding.TupleSetter[$T] with _root_.com.twitter.bijection.macros.MacroGenerated {
      override def apply(t: $T): _root_.cascading.tuple.Tuple = {
        val tup = _root_.cascading.tuple.Tuple.size(${set.size})
        ..$set
        tup
      }
      override val arity: scala.Int = ${set.size}
    }
    """
    c.Expr[TupleSetter[T]](res)
  }

  /** Macro entry point producing a TupleConverter[T] without asking for IsCaseClass evidence. */
  def caseClassTupleConverterNoProof[T]: TupleConverter[T] = macro caseClassTupleConverterNoProofImpl[T]

  /**
   * Implementation for the converter macros that carry IsCaseClass evidence. The `proof` expression
   * is not inspected; it only has to exist for the macro to be applicable. Delegates to
   * caseClassTupleConverterNoProofImpl.
   */
  def caseClassTupleConverterImpl[T](c: Context)(proof: c.Expr[IsCaseClass[T]])(implicit T: c.WeakTypeTag[T]): c.Expr[TupleConverter[T]] =
    caseClassTupleConverterNoProofImpl(c)(T)

  /**
   * Generates an anonymous TupleConverter[T] (also marked MacroGenerated) whose `apply` rebuilds T
   * from a TupleEntry by calling the companion `apply` of T and of every nested case class with
   * values read from consecutive tuple positions, and whose `arity` is the flattened field count.
   */
  def caseClassTupleConverterNoProofImpl[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[TupleConverter[T]] = {
    import c.universe._

    // builder: expression producing one constructor argument;
    // size: how many tuple positions that expression consumes.
    case class AccessorBuilder(builder: Tree, size: Int)

    // A leaf field reads exactly one position through the given TupleEntry getter.
    def getPrimitive(strAccessor: Tree): Int => AccessorBuilder =
      { (idx: Int) =>
        AccessorBuilder(q"""${strAccessor}(${idx})""", 1)
      }

    // Assigns consecutive positions, starting at idx, to the children and wraps their expressions
    // in a call to the companion of tpe.
    def flattenAccessorBuilders(tpe: Type, idx: Int, childGetters: List[(Int => AccessorBuilder)]): AccessorBuilder = {
      val (_, accessors) = childGetters.foldLeft((idx, List[AccessorBuilder]())) {
        case ((curIdx, eles), t) =>
          val nextEle = t(curIdx)
          val idxIncr = nextEle.size
          (curIdx + idxIncr, eles :+ nextEle)
      }

      val builder = q"""
      ${tpe.typeSymbol.companionSymbol}(..${accessors.map(_.builder)})
      """
      // sum (rather than reduce) so that a case class without fields yields size 0
      // instead of throwing during macro expansion.
      val size = accessors.map(_.size).sum
      AccessorBuilder(builder, size)
    }

    // One index-parameterised builder per case accessor of outerTpe.
    def expandMethod(outerTpe: Type): List[(Int => AccessorBuilder)] = {
      outerTpe.declarations
        .collect { case m: MethodSymbol if m.isCaseAccessor => m.returnType }
        .toList
        .map { fieldTpe =>
          fieldTpe match {
            case tpe if tpe =:= typeOf[String] => getPrimitive(q"t.getString")
            // Must be getBoolean: the previous `t.Boolean` named a member TupleEntry does not
            // have, so converters for Boolean fields expanded to code that failed to compile.
            case tpe if tpe =:= typeOf[Boolean] => getPrimitive(q"t.getBoolean")
            case tpe if tpe =:= typeOf[Short] => getPrimitive(q"t.getShort")
            case tpe if tpe =:= typeOf[Int] => getPrimitive(q"t.getInteger")
            case tpe if tpe =:= typeOf[Long] => getPrimitive(q"t.getLong")
            case tpe if tpe =:= typeOf[Float] => getPrimitive(q"t.getFloat")
            case tpe if tpe =:= typeOf[Double] => getPrimitive(q"t.getDouble")
            case tpe if IsCaseClassImpl.isCaseClassType(c)(tpe) =>
              { (idx: Int) =>
                {
                  val childGetters = expandMethod(tpe)
                  flattenAccessorBuilders(tpe, idx, childGetters)
                }
              }
            case _ => c.abort(c.enclosingPosition, s"Case class ${T} is not pure primitives or nested case classes")
          }
        }
    }

    val accessorBuilders = flattenAccessorBuilders(T.tpe, 0, expandMethod(T.tpe))

    val res = q"""
    new _root_.com.twitter.scalding.TupleConverter[$T] with _root_.com.twitter.bijection.macros.MacroGenerated {
      override def apply(t: _root_.cascading.tuple.TupleEntry): $T = {
        ${accessorBuilders.builder}
      }
      override val arity: scala.Int = ${accessorBuilders.size}
    }
    """
    c.Expr[TupleConverter[T]](res)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.twitter.scalding.macros

import org.scalatest.WordSpec
import com.twitter.scalding.macros.{ _ => _ }

/**
 * Checks that the macro expansions compile when the call site has imported nothing from the
 * macros package. Every reference below is therefore spelled with its fully qualified path.
 */
class MacroDepHygiene extends WordSpec {
  import com.twitter.bijection.macros.impl.IsCaseClassImpl

  case class A(x: Int, y: String)
  case class B(x: A, y: String, z: A)
  class C

  // True when the value carries the MacroGenerated marker trait.
  def isMg(a: Any): Boolean = a match {
    case _: com.twitter.bijection.macros.MacroGenerated => true
    case _ => false
  }

  "TupleSetter macro" should {
    // Resolves an implicit TupleSetter[T] and reports whether it was macro generated.
    def isTupleSetterAvailable[T](implicit proof: com.twitter.scalding.TupleSetter[T]): Boolean =
      isMg(proof)

    "work fine without any imports" in {
      com.twitter.scalding.macros.Macros.caseClassTupleSetter[A]
      com.twitter.scalding.macros.Macros.caseClassTupleSetter[B]
    }

    "implicitly work fine without any imports" in {
      import com.twitter.scalding.macros.MacroImplicits.materializeCaseClassTupleSetter
      assert(isTupleSetterAvailable[A])
      assert(isTupleSetterAvailable[B])
    }

    "fail if not a case class" in {
      assert(!isTupleSetterAvailable[C])
    }
  }

  "TupleConverter macro" should {
    // Resolves an implicit TupleConverter[T] and reports whether it was macro generated.
    def isTupleConverterAvailable[T](implicit proof: com.twitter.scalding.TupleConverter[T]): Boolean =
      isMg(proof)

    "work fine without any imports" in {
      com.twitter.scalding.macros.Macros.caseClassTupleConverter[A]
      com.twitter.scalding.macros.Macros.caseClassTupleConverter[B]
    }

    "implicitly work fine without any imports" in {
      import com.twitter.scalding.macros.MacroImplicits.materializeCaseClassTupleConverter
      assert(isTupleConverterAvailable[A])
      assert(isTupleConverterAvailable[B])
    }

    "fail if not a case class" in {
      assert(!isTupleConverterAvailable[C])
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package com.twitter.scalding.macros

import cascading.tuple.{ Tuple => CTuple, TupleEntry }

import org.scalatest.{ Matchers, WordSpec }

import com.twitter.scalding._
import com.twitter.scalding.macros._
import com.twitter.scalding.macros.impl._
import com.twitter.scalding.serialization.Externalizer

import com.twitter.bijection.macros.{ IsCaseClass, MacroGenerated }

// We avoid nesting these just to avoid any complications in the serialization test
// (they are declared at the top level of the file rather than inside the test class).
// SampleClassA: two leaf fields (Int, String).
case class SampleClassA(x: Int, y: String)
// SampleClassB: two nested SampleClassA values followed by a String; the tests below use a 5-element tuple for it.
case class SampleClassB(a1: SampleClassA, a2: SampleClassA, y: String)
// SampleClassC: a mix of nested A and B values; the tests below expect it to flatten to 19 tuple positions.
case class SampleClassC(a: SampleClassA, b: SampleClassB, c: SampleClassA, d: SampleClassB, e: SampleClassB)

/**
 * Exercises the macro-materialized TupleSetter / TupleConverter instances for the sample case
 * classes: Java serializability, the layout of the produced tuple, and round trips in both
 * directions.
 */
class MacrosUnitTests extends WordSpec with Matchers {
  import MacroImplicits._

  // Asserts that the value is MacroGenerated and hands it back for further use.
  def isMg[T](t: T): T = {
    t shouldBe a[MacroGenerated]
    t
  }

  // Convert / set through the implicit instance, after checking that it is macro generated.
  def mgConv[T](te: TupleEntry)(implicit conv: TupleConverter[T]): T = isMg(conv)(te)
  def mgSet[T](t: T)(implicit set: TupleSetter[T]): TupleEntry = new TupleEntry(isMg(set)(t))

  def shouldRoundTrip[T: IsCaseClass: TupleSetter: TupleConverter](t: T): Unit = {
    t shouldBe mgConv(mgSet(t))
  }

  def shouldRoundTripOther[T: IsCaseClass: TupleSetter: TupleConverter](te: TupleEntry, t: T): Unit = {
    val inter = mgConv(te)
    inter shouldBe t
    mgSet(inter) shouldBe te
  }

  def canExternalize(t: AnyRef): Unit = {
    Externalizer(t).javaWorks shouldBe true
  }

  "MacroGenerated TupleSetter" should {
    def doesJavaWork[T](implicit set: TupleSetter[T]): Unit = {
      canExternalize(isMg(set))
    }
    "be serializable for case class A" in { doesJavaWork[SampleClassA] }
    "be serializable for case class B" in { doesJavaWork[SampleClassB] }
    "be serializable for case class C" in { doesJavaWork[SampleClassC] }
  }

  "MacroGenerated TupleConverter" should {
    def doesJavaWork[T](implicit conv: TupleConverter[T]): Unit = {
      canExternalize(isMg(conv))
    }
    "be serializable for case class A" in { doesJavaWork[SampleClassA] }
    "be serializable for case class B" in { doesJavaWork[SampleClassB] }
    "be serializable for case class C" in { doesJavaWork[SampleClassC] }
  }

  "MacroGenerated TupleSetter and TupleConverter" should {
    "round trip class -> tupleentry -> class" in {
      shouldRoundTrip(SampleClassA(100, "onehundred"))
      shouldRoundTrip(SampleClassB(SampleClassA(100, "onehundred"), SampleClassA(-1, "zero"), "what"))
      val a = SampleClassA(73, "hrm")
      val b = SampleClassB(a, a, "hrm")
      shouldRoundTrip(b)
      shouldRoundTrip(SampleClassC(a, b, SampleClassA(123980, "hey"), SampleClassB(a, SampleClassA(-1, "zero"), "zoo"), b))
    }

    "Case Class should form expected tuple" in {
      val input = SampleClassC(
        SampleClassA(1, "asdf"),
        SampleClassB(SampleClassA(2, "bcdf"), SampleClassA(5, "jkfs"), "wetew"),
        SampleClassA(9, "xcmv"),
        SampleClassB(SampleClassA(23, "ck"), SampleClassA(13, "dafk"), "xcv"),
        SampleClassB(SampleClassA(34, "were"), SampleClassA(654, "power"), "adsfmx"))
      val setter = implicitly[TupleSetter[SampleClassC]]
      val tup = setter(input)
      assert(tup.size == 19)
      assert(tup.get(0) === 1)
      assert(tup.get(18) === "adsfmx")
    }

    "round trip tupleentry -> class -> tupleEntry" in {
      // Writes the two columns of SampleClassA(100, "onehundred") starting at position `at`.
      def putA(tup: CTuple, at: Int): Unit = {
        tup.setInteger(at, 100)
        tup.setString(at + 1, "onehundred")
      }
      // Writes the five columns of SampleClassB(a, a, "what") starting at position `at`.
      def putB(tup: CTuple, at: Int): Unit = {
        putA(tup, at)
        putA(tup, at + 2)
        tup.setString(at + 4, "what")
      }

      val a = SampleClassA(100, "onehundred")
      val a_tup = CTuple.size(2)
      putA(a_tup, 0)
      shouldRoundTripOther(new TupleEntry(a_tup), a)

      val b = SampleClassB(a, a, "what")
      val b_tup = CTuple.size(5)
      putB(b_tup, 0)
      shouldRoundTripOther(new TupleEntry(b_tup), b)

      // Layout of SampleClassC(a, b, a, b, b): A at 0, B at 2, A at 7, B at 9, B at 14.
      val c = SampleClassC(a, b, a, b, b)
      val c_tup = CTuple.size(19)
      putA(c_tup, 0)
      putB(c_tup, 2)
      putA(c_tup, 7)
      putB(c_tup, 9)
      putB(c_tup, 14)
      shouldRoundTripOther(new TupleEntry(c_tup), c)
    }
  }
}

0 comments on commit e9b4f44

Please sign in to comment.