Skip to content

Commit

Permalink
Ability to relocate/shade in assembly
Browse files Browse the repository at this point in the history
  • Loading branch information
joan38 committed Aug 19, 2020
1 parent 5e164d3 commit 8ae6772
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 69 deletions.
4 changes: 3 additions & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ object Deps {
val utest = ivy"com.lihaoyi::utest:0.7.4"
val zinc = ivy"org.scala-sbt::zinc:1.4.0-M1"
val bsp = ivy"ch.epfl.scala:bsp4j:2.0.0-M4"
val jarjarabrams = ivy"com.eed3si9n.jarjarabrams::jarjar-abrams-core:0.2.0"
}

trait MillPublishModule extends PublishModule{
Expand Down Expand Up @@ -167,7 +168,8 @@ object main extends MillModule {
// Necessary so we can share the JNA classes throughout the build process
Deps.jna,
Deps.jnaPlatform,
Deps.coursier
Deps.coursier,
Deps.jarjarabrams
)

def generatedSources = T {
Expand Down
78 changes: 34 additions & 44 deletions main/src/modules/Assembly.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package mill.modules

import com.eed3si9n.jarjarabrams.{ShadePattern, ShadeRule, Shader}
import java.io.InputStream
import java.util.jar.JarFile
import java.util.regex.Pattern

import geny.Generator
import mill.Agg

import os.Generator
import scala.collection.JavaConverters._

object Assembly {
Expand All @@ -32,13 +31,15 @@ object Assembly {

case class Exclude(path: String) extends Rule

case class Relocate(from: String, to: String) extends Rule

object ExcludePattern {
def apply(pattern: String): ExcludePattern = ExcludePattern(Pattern.compile(pattern))
}
case class ExcludePattern(pattern: Pattern) extends Rule
}

def groupAssemblyEntries(inputPaths: Agg[os.Path], assemblyRules: Seq[Assembly.Rule]): Map[String, GroupedEntry] = {
def groupAssemblyEntries(mappings: Generator[(String, InputStream)], assemblyRules: Seq[Assembly.Rule]): Map[String, GroupedEntry] = {
val rulesMap = assemblyRules.collect {
case r@Rule.Append(path, _) => path -> r
case r@Rule.Exclude(path) => path -> r
Expand All @@ -52,23 +53,19 @@ object Assembly {
case Rule.ExcludePattern(pattern) => pattern.asPredicate().test(_)
}

classpathIterator(inputPaths).foldLeft(Map.empty[String, GroupedEntry]) {
case (entries, entry) =>
val mapping = entry.mapping

mappings.foldLeft(Map.empty[String, GroupedEntry]) {
case (entries, (mapping, entry)) =>
rulesMap.get(mapping) match {
case Some(_: Assembly.Rule.Exclude) =>
entries
case Some(a: Assembly.Rule.Append) =>
val newEntry = entries.getOrElse(mapping, AppendEntry(Nil, a.separator)).append(entry)
val newEntry = entries.getOrElse(mapping, AppendEntry(Seq.empty, a.separator)).append(entry)
entries + (mapping -> newEntry)

case _ if excludePatterns.exists(_(mapping)) =>
entries
case _ if appendPatterns.exists(_(mapping)) =>
val newEntry = entries.getOrElse(mapping, AppendEntry.empty).append(entry)
entries + (mapping -> newEntry)

case _ if !entries.contains(mapping) =>
entries + (mapping -> WriteOnceEntry(entry))
case _ =>
Expand All @@ -77,52 +74,45 @@ object Assembly {
}
}

private def classpathIterator(inputPaths: Agg[os.Path]): Generator[AssemblyEntry] = {
def loadShadedClasspath(
inputPaths: Agg[os.Path],
assemblyRules: Seq[Assembly.Rule]
): Generator[(String, InputStream)] = {
val shadeRules = assemblyRules.collect {
case Rule.Relocate(from, to) => ShadePattern.Rename(List(from -> to)).inAll
}

Generator.from(inputPaths)
.filter(os.exists)
.flatMap {
p =>
if (os.isFile(p)) {
val jf = new JarFile(p.toIO)
Generator.from(
for(entry <- jf.entries().asScala if !entry.isDirectory)
yield JarFileEntry(entry.getName, () => jf.getInputStream(entry))
)
}
else {
os.walk.stream(p)
.filter(os.isFile)
.map(sub => PathEntry(sub.relativeTo(p).toString, sub))
}
.flatMap { path =>
val shader = Shader.inputStreamShader(shadeRules, verbose = false)
(if (os.isFile(path)) {
val jarFile = new JarFile(path.toIO)
Generator.from(jarFile.entries().asScala.filterNot(_.isDirectory))
.flatMap(entry => shader(jarFile.getInputStream(entry), entry.getName))
}
else {
os.walk
.stream(path)
.filter(os.isFile)
.flatMap(subPath => shader(os.read.inputStream(subPath), subPath.relativeTo(path).toString))
}).map(_.swap)
}
}
}

private[modules] sealed trait GroupedEntry {
def append(entry: AssemblyEntry): GroupedEntry
def append(entry: InputStream): GroupedEntry
}

private[modules] object AppendEntry {
val empty: AppendEntry = AppendEntry(Nil, Assembly.defaultSeparator)
}

private[modules] case class AppendEntry(entries: List[AssemblyEntry], separator: String) extends GroupedEntry {
def append(entry: AssemblyEntry): GroupedEntry = copy(entries = entry :: this.entries)
}

private[modules] case class WriteOnceEntry(entry: AssemblyEntry) extends GroupedEntry {
def append(entry: AssemblyEntry): GroupedEntry = this
}

private[this] sealed trait AssemblyEntry {
def mapping: String
def inputStream: InputStream
}

private[this] case class PathEntry(mapping: String, path: os.Path) extends AssemblyEntry {
def inputStream: InputStream = os.read.inputStream(path)
private[modules] case class AppendEntry(entries: Seq[InputStream], separator: String) extends GroupedEntry {
def append(entry: InputStream): GroupedEntry = copy(entries = entry +: entries)
}

private[this] case class JarFileEntry(mapping: String, getIs: () => InputStream) extends AssemblyEntry {
def inputStream: InputStream = getIs()
private[modules] case class WriteOnceEntry(entry: InputStream) extends GroupedEntry {
def append(entry: InputStream): GroupedEntry = this
}
41 changes: 17 additions & 24 deletions main/src/modules/Jvm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@ import java.nio.file.{FileSystems, Files, StandardOpenOption}
import java.nio.file.attribute.PosixFilePermission
import java.util.Collections
import java.util.jar.{Attributes, JarEntry, JarFile, JarOutputStream, Manifest}

import coursier.{Dependency, Fetch, Repository, Resolution}
import coursier.{Dependency, Repository, Resolution}
import coursier.util.{Gather, Task}
import geny.Generator
import mill.main.client.InputPumper
import mill.eval.{PathRef, Result}
import mill.util.Ctx
import mill.api.IO
import mill.api.Loose.Agg

import scala.collection.mutable
import scala.collection.JavaConverters._
import upickle.default.{macroRW, ReadWriter => RW}
import upickle.default.{ReadWriter => RW}

object Jvm {
/**
Expand Down Expand Up @@ -293,27 +290,23 @@ object Jvm {
manifest.build.write(manifestOut)
manifestOut.close()

Assembly.groupAssemblyEntries(inputPaths, assemblyRules).view
.foreach {
case (mapping, AppendEntry(entries, separator)) =>
val path = zipFs.getPath(mapping).toAbsolutePath
val separated =
if (entries.isEmpty) Nil
else
entries.head +: entries.tail.flatMap { e =>
List(JarFileEntry(e.mapping, () => new ByteArrayInputStream(separator.getBytes)), e)
}
val concatenated = new SequenceInputStream(
Collections.enumeration(separated.map(_.inputStream).asJava))
writeEntry(path, concatenated, append = true)
case (mapping, WriteOnceEntry(entry)) =>
val path = zipFs.getPath(mapping).toAbsolutePath
writeEntry(path, entry.inputStream, append = false)
}

val mappings = Assembly.loadShadedClasspath(inputPaths, assemblyRules)
Assembly.groupAssemblyEntries(mappings, assemblyRules).foreach {
case (mapping, AppendEntry(entries, separator)) =>
val path = zipFs.getPath(mapping).toAbsolutePath
val separated =
if (entries.isEmpty) Nil
else entries.head +: entries.tail.flatMap(e => Seq(new ByteArrayInputStream(separator.getBytes), e))
val concatenated = new SequenceInputStream(Collections.enumeration(separated.asJava))
writeEntry(path, concatenated, append = true)
case (mapping, WriteOnceEntry(entry)) =>
val path = zipFs.getPath(mapping).toAbsolutePath

writeEntry(path, entry, append = false)
}
zipFs.close()
val output = ctx.dest / "out.jar"

val output = ctx.dest / "out.jar"
// Prepend shell script and make it executable
if (prependShellScript.isEmpty) os.move(tmp, output)
else{
Expand Down
100 changes: 100 additions & 0 deletions main/src/modules/Shader.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package com.eed3si9n.jarjarabrams

import java.io.{ByteArrayInputStream, InputStream}
import java.nio.file.{ Files, Path, StandardOpenOption }
import org.pantsbuild.jarjar.{ JJProcessor, _ }
import org.pantsbuild.jarjar.util.EntryStruct
import com.eed3si9n.jarjarabrams.Utils.readAllBytes

object Shader {
def shadeDirectory(
rules: Seq[ShadeRule],
dir: Path,
mappings: Iterable[(Path, String)],
verbose: Boolean
): Unit = {
val inputStreams = mappings.filter(x => !Files.isDirectory(x._1)).map(x => Files.newInputStream(x._1) -> x._2)
val newMappings = shadeInputStreams(rules, inputStreams, verbose)
mappings.filterNot(_._1.toFile.isDirectory).foreach(f => Files.delete(f._1))
newMappings.foreach { case (inputStream, mapping) =>
val out = dir.resolve(mapping)
if (!Files.exists(out.getParent)) Files.createDirectories(out.getParent)
Files.write(out, readAllBytes(inputStream), StandardOpenOption.CREATE)
}
}

def shadeInputStreams(
rules: Seq[ShadeRule],
mappings: Iterable[(InputStream, String)],
verbose: Boolean
): Iterable[(InputStream, String)] = {
val shader = inputStreamShader(rules, verbose)
mappings.flatMap(f => shader(f._1, f._2))
}

def inputStreamShader(
rules: Seq[ShadeRule],
verbose: Boolean
): (InputStream, String) => Option[(InputStream, String)] = {
val jjrules = rules.flatMap { r =>
r.shadePattern match {
case ShadePattern.Rename(patterns) =>
patterns.map { case (from, to) =>
val jrule = new Rule()
jrule.setPattern(from)
jrule.setResult(to)
jrule
}
case ShadePattern.Zap(patterns) =>
patterns.map { pattern =>
val jrule = new Zap()
jrule.setPattern(pattern)
jrule
}
case ShadePattern.Keep(patterns) =>
patterns.map { pattern =>
val jrule = new Keep()
jrule.setPattern(pattern)
jrule
}
case _ => Nil
}
}

val proc = new JJProcessor(jjrules, verbose, true, null)

{ (inputStream, mapping) =>
/*
jarjar MisplacedClassProcessor class transforms byte[] to a class using org.objectweb.asm.ClassReader.getClassName
which always translates class names containing '.' into '/', regardless of OS platform.
We need to transform any windows file paths in order for jarjar to match them properly and not omit them.
*/
val sanitizedMapping = if (mapping.contains('\\')) mapping.replace('\\', '/') else mapping
val entry = new EntryStruct
entry.data = readAllBytes(inputStream)
entry.name = sanitizedMapping
entry.time = -1
entry.skipTransform = false
val shadedInputStream =
if (proc.process(entry)) Some(new ByteArrayInputStream(entry.data) {
override def close(): Unit = inputStream.close()
} -> entry.name)
else None
val excludes = proc.getExcludes
shadedInputStream.filterNot(a => excludes.contains(a._2))
}
}
}

sealed trait ShadePattern {
def inAll: ShadeRule = ShadeRule(this, Vector(ShadeTarget.inAll))
def inProject: ShadeRule = ShadeRule(this, Vector(ShadeTarget.inProject))
def inModuleCoordinates(moduleId: ModuleCoordinate*): ShadeRule =
ShadeRule(this, moduleId.toVector map ShadeTarget.inModuleCoordinate)
}

object ShadePattern {
case class Rename(patterns: List[(String, String)]) extends ShadePattern
case class Zap(patterns: List[String]) extends ShadePattern
case class Keep(patterns: List[String]) extends ShadePattern
}

0 comments on commit 8ae6772

Please sign in to comment.