Skip to content

Commit

Permalink
adding support for customizable constructor annotations, codec and "i…
Browse files Browse the repository at this point in the history
…nclusive dot" (#481)
  • Loading branch information
Iurii Malchenko authored and lihaoyi committed Jan 4, 2019
1 parent 54b797c commit eb497a8
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 47 deletions.
6 changes: 3 additions & 3 deletions contrib/twirllib/src/TwirlModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ trait TwirlModule extends mill.Module {

def twirlAdditionalImports: Seq[String] = Nil

private def twirlConstructorAnnotations: Seq[String] = Nil
def twirlConstructorAnnotations: Seq[String] = Nil

private def twirlCodec: Codec = Codec(Properties.sourceEncoding)
def twirlCodec: Codec = Codec(Properties.sourceEncoding)

private def twirlInclusiveDot: Boolean = false
def twirlInclusiveDot: Boolean = false

def compileTwirl: T[mill.scalalib.api.CompilationResult] = T.persistent {
TwirlWorkerApi.twirlWorker
Expand Down
79 changes: 49 additions & 30 deletions contrib/twirllib/src/TwirlWorker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package twirllib
import java.io.File
import java.lang.reflect.Method
import java.net.URLClassLoader
import java.nio.charset.Charset

import mill.api.PathRef
import mill.scalalib.api.CompilationResult
Expand All @@ -22,24 +23,29 @@ class TwirlWorker {

// Switched to using the java api because of the hack-ish thing going on later.
//
// * we'll need to construct a collection of additional imports
// * it will need to consider the defaults
// * and add the user-provided additional imports
// * we'll need to construct a collection of additional imports (will need to also include the defaults and add the user-provided additional imports)
// * we'll need to construct a collection of constructor annotations// *
// * the default collection in scala api is a Seq[String]
// * but it is defined in a different classloader (namely in cl)
// * so we can not construct our own Seq and pass it to the method - it will be from our classloader, and not compatible
// * the java api has a Collection as the type for this param, for which it is much more doable to append things to it using reflection
// * the java api uses java collections, manipulating which using reflection is much simpler
//
// NOTE: I tried creating the cl classloader passing the current classloader as the parent:
// NOTE: When creating the cl classloader with passing the current classloader as the parent:
// val cl = new URLClassLoader(twirlClasspath.map(_.toIO.toURI.toURL).toArray, getClass.getClassLoader)
// in that case it was possible to cast the default to a Seq[String], construct our own Seq[String], and pass it to the method invoke- it was compatible.
// And the tests passed. But when run in a different mill project, I was getting exceptions like this:
// it is possible to cast the default to a Seq[String], construct our own Seq[String], and pass it to the method invoke -
// classe will be compatible (the tests passed).
// But when run in an actual mill project with this module enabled, there were exceptions like this:
// scala.reflect.internal.MissingRequirementError: object scala in compiler mirror not found.

val twirlCompilerClass = cl.loadClass("play.japi.twirl.compiler.TwirlCompiler")

// this one is only to get the codec: Codec parameter default value
val twirlScalaCompilerClass = cl.loadClass("play.twirl.compiler.TwirlCompiler")
val codecClass = cl.loadClass("scala.io.Codec")
val charsetClass = cl.loadClass("java.nio.charset.Charset")
val arrayListClass = cl.loadClass("java.util.ArrayList")
val hashSetClass = cl.loadClass("java.util.HashSet")

val codecApplyMethod = codecClass.getMethod("apply", charsetClass)
val charsetForNameMethod = charsetClass.getMethod("forName", classOf[java.lang.String])

val compileMethod = twirlCompilerClass.getMethod("compile",
classOf[java.io.File],
Expand All @@ -51,11 +57,9 @@ class TwirlWorker {
cl.loadClass("scala.io.Codec"),
classOf[Boolean])

val arrayListClass = cl.loadClass("java.util.ArrayList")
val hashSetClass = cl.loadClass("java.util.HashSet")
val defaultImportsMethod = twirlCompilerClass.getField("DEFAULT_IMPORTS")

val defaultAdditionalImportsMethod = twirlCompilerClass.getField("DEFAULT_IMPORTS")
val defaultCodecMethod = twirlScalaCompilerClass.getMethod("compile$default$7")
val hashSetConstructor = hashSetClass.getConstructor(cl.loadClass("java.util.Collection"))

val instance = new TwirlWorkerApi {
override def compileTwirl(source: File,
Expand All @@ -66,27 +70,42 @@ class TwirlWorker {
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean) {
val defaultAdditionalImports = defaultAdditionalImportsMethod.get(null) // unmodifiable collection
// copying it into a modifiable hash set and adding all additional imports
val allAdditionalImports =
hashSetClass
.getConstructor(cl.loadClass("java.util.Collection"))
.newInstance(defaultAdditionalImports)
.asInstanceOf[Object]
val hashSetAddMethod =
allAdditionalImports
.getClass
.getMethod("add", classOf[Object])
additionalImports.foreach(hashSetAddMethod.invoke(allAdditionalImports, _))

// val defaultImports = play.japi.twirl.compiler.TwirlCompiler.DEFAULT_IMPORTS()
// val twirlAdditionalImports = new HashSet(defaultImports)
// additionalImports.foreach(twirlAdditionalImports.add)
val defaultImports = defaultImportsMethod.get(null) // unmodifiable collection
val twirlAdditionalImports = hashSetConstructor.newInstance(defaultImports).asInstanceOf[Object]
val hashSetAddMethod = twirlAdditionalImports.getClass.getMethod("add", classOf[Object])
additionalImports.foreach(hashSetAddMethod.invoke(twirlAdditionalImports, _))

// Codec.apply(Charset.forName(codec.charSet.name()))
val twirlCodec = codecApplyMethod.invoke(null, charsetForNameMethod.invoke(null, codec.charSet.name()))

// val twirlConstructorAnnotations = new ArrayList()
// constructorAnnotations.foreach(twirlConstructorAnnotations.add)
val twirlConstructorAnnotations = arrayListClass.newInstance().asInstanceOf[Object]
val arrayListAddMethod = twirlConstructorAnnotations.getClass.getMethod("add", classOf[Object])
constructorAnnotations.foreach(arrayListAddMethod.invoke(twirlConstructorAnnotations, _))

// JavaAPI
// public static Optional<File> compile(
// File source,
// File sourceDirectory,
// File generatedDirectory,
// String formatterType,
// Collection<String> additionalImports,
// List<String> constructorAnnotations,
// Codec codec,
// boolean inclusiveDot
// )
val o = compileMethod.invoke(null, source,
sourceDirectory,
generatedDirectory,
formatterType,
allAdditionalImports,
arrayListClass.newInstance().asInstanceOf[Object], // empty list seems to be the default
defaultCodecMethod.invoke(null),
Boolean.box(false)
twirlAdditionalImports,
twirlConstructorAnnotations,
twirlCodec,
Boolean.box(inclusiveDot)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@this(title: String)
@wrapper {
<html>
<body>
<h1>@title</h1>
</body>
</html>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@(content: Html)

@defining("test") { className =>
<div class="@className">@content</div>
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@(title: String)
@this(title: String)
<html>
<body>
<h1>@title</h1>
Expand Down
74 changes: 63 additions & 11 deletions contrib/twirllib/test/src/HelloWorldTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,40 @@ object HelloWorldTests extends TestSuite {
}

trait HelloWorldModule extends mill.twirllib.TwirlModule {
def twirlVersion = "1.0.0"
override def twirlAdditionalImports: Seq[String] = additionalImports

def twirlVersion = "1.3.15"

}

object HelloWorld extends HelloBase {

object core extends HelloWorldModule {
override def twirlVersion = "1.3.15"
override def twirlAdditionalImports: Seq[String] = testAdditionalImports
override def twirlConstructorAnnotations: Seq[String] = testConstructorAnnotations
}

}

val resourcePath: os.Path = os.pwd / 'contrib / 'twirllib / 'test / 'resources / "hello-world"
object HelloWorldWithInclusiveDot extends HelloBase {

object core extends HelloWorldModule {
override def twirlInclusiveDot: Boolean = true
}

}

def workspaceTest[T](
m: TestUtil.BaseModule,
resourcePath: os.Path = resourcePath
resourcePathSuffix: String
)(t: TestEvaluator => T)(implicit tp: TestPath): T = {
val eval = new TestEvaluator(m)
os.remove.all(m.millSourcePath)
os.remove.all(eval.outPath)
os.makeDir.all(m.millSourcePath / os.up)
os.copy(resourcePath, m.millSourcePath)
os.copy(
os.pwd / 'contrib / 'twirllib / 'test / 'resources / resourcePathSuffix,
m.millSourcePath
)
t(eval)
}

Expand All @@ -52,15 +64,20 @@ object HelloWorldTests extends TestSuite {
"import _root_.play.twirl.api.Xml"
)

def additionalImports: Seq[String] = Seq(
def testAdditionalImports: Seq[String] = Seq(
"mill.twirl.test.AdditionalImport1._",
"mill.twirl.test.AdditionalImport2._"
)

def testConstructorAnnotations = Seq(
"@org.springframework.stereotype.Component()",
"@something.else.Thing()"
)

def tests: Tests = Tests {
'twirlVersion - {

'fromBuild - workspaceTest(HelloWorld) { eval =>
'fromBuild - workspaceTest(HelloWorld, "hello-world") { eval =>
val Right((result, evalCount)) =
eval.apply(HelloWorld.core.twirlVersion)

Expand All @@ -70,7 +87,7 @@ object HelloWorldTests extends TestSuite {
)
}
}
'compileTwirl - workspaceTest(HelloWorld) { eval =>
'compileTwirl - workspaceTest(HelloWorld, "hello-world") { eval =>
val Right((result, evalCount)) = eval.apply(HelloWorld.core.compileTwirl)

val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
Expand All @@ -86,8 +103,43 @@ object HelloWorldTests extends TestSuite {
evalCount > 0,
outputFiles.forall { p =>
val lines = os.read.lines(p).map(_.trim)
(expectedDefaultImports ++ additionalImports.map(s => s"import $s")).forall(lines.contains)
}
(expectedDefaultImports ++ testAdditionalImports.map(s => s"import $s")).forall(lines.contains)
},
outputFiles.filter(_.toString().contains("hello.template.scala")).forall { p =>
val lines = os.read.lines(p).map(_.trim)
val expectedClassDeclaration = s"class hello ${testConstructorAnnotations.mkString}"
lines.exists(_.startsWith(expectedClassDeclaration))
},

)

// don't recompile if nothing changed
val Right((_, unchangedEvalCount)) =
eval.apply(HelloWorld.core.compileTwirl)

assert(unchangedEvalCount == 0)
}
'compileTwirlInclusiveDot - workspaceTest(HelloWorldWithInclusiveDot, "hello-world-inclusive-dot") { eval =>
val Right((result, evalCount)) = eval.apply(HelloWorldWithInclusiveDot.core.compileTwirl)

val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
val expectedClassfiles = compileClassfiles.map( name =>
eval.outPath / 'core / 'compileTwirl / 'dest / 'html / name.toString().replace(".template.scala", "$$TwirlInclusiveDot.template.scala")
)

println(s"outputFiles: $outputFiles")

assert(
result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest,
outputFiles.nonEmpty,
outputFiles.forall(expectedClassfiles.contains),
outputFiles.size == 2,
evalCount > 0,
outputFiles.filter(_.toString().contains("hello.template.scala")).forall { p =>
val lines = os.read.lines(p).map(_.trim)
lines.exists(_.contains("$$TwirlInclusiveDot"))
},

)

// don't recompile if nothing changed
Expand Down
6 changes: 4 additions & 2 deletions docs/pages/9 - Contrib Modules.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,10 @@ object app extends ScalaModule with TwirlModule {
#### Configuration options

* `def twirlVersion: T[String]` (mandatory) - the version of the twirl compiler to use, like "1.3.15"
* `def twirlAdditionalImports: Seq[String] = Nil` - the additional imports that will be added by twirl compiler to the top
of all templates
* `def twirlAdditionalImports: Seq[String] = Nil` - the additional imports that will be added by twirl compiler to the top of all templates
* `def twirlConstructorAnnotations: Seq[String] = Nil` - annotations added to the generated classes' constructors (note it only applies to templates with `@this(...)` constructors)
* `def twirlCodec = Codec(Properties.sourceEncoding)` - the codec used to generate the files (the default is the same sbt plugin uses)
* `def twirlInclusiveDot: Boolean = false`

#### Details

Expand Down

0 comments on commit eb497a8

Please sign in to comment.