Skip to content

Commit

Permalink
Various improvements for Scala 3 macro to match Scala 2 implementation (
Browse files Browse the repository at this point in the history
#148)

These issues were found while porting Mill to Scala 3.
  • Loading branch information
bishabosha authored Aug 14, 2024
1 parent 4100113 commit 37d64b8
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 31 deletions.
2 changes: 1 addition & 1 deletion example/optseq/src/Main.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package example.optseq
import mainargs.{main, arg, ParserForMethods, ArgReader}
import mainargs.{main, arg, ParserForMethods, TokensReader}

object Main {
@main
Expand Down
125 changes: 96 additions & 29 deletions mainargs/src-3/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,34 @@ object Macros {
val companionModuleType = typeSymbolOfB.companionModule.tree.asInstanceOf[ValDef].tpt.tpe.asType
val companionModuleExpr = Ident(companionModule).asExpr
val mainAnnotationInstance = typeSymbolOfB.getAnnotation(mainAnnotation).getOrElse {
report.throwError(
s"cannot find @main annotation on ${companionModule.name}",
typeSymbolOfB.pos.get
'{new mainargs.main()}.asTerm // construct a default if not found.
}
val ctor = typeSymbolOfB.primaryConstructor
val ctorParams = ctor.paramSymss.flatten
// try to match the apply method with the constructor parameters, this is a good heuristic
// for if the apply method is overloaded.
val annotatedMethod = typeSymbolOfB.companionModule.memberMethod("apply").filter(p =>
p.paramSymss.flatten.corresponds(ctorParams) { (p1, p2) =>
p1.name == p2.name
}
).headOption.getOrElse {
report.errorAndAbort(
s"Cannot find apply method in companion object of ${typeReprOfB.show}",
typeSymbolOfB.companionModule.pos.getOrElse(Position.ofMacroExpansion)
)
}
val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head
companionModuleType match
case '[bCompanion] =>
val mainData = createMainData[B, Any](
val mainData = createMainData[B, bCompanion](
annotatedMethod,
mainAnnotationInstance,
// Somehow the `apply` method parameter annotations don't end up on
// the `apply` method parameters, but end up in the `<init>` method
// parameters, so use those for getting the annotations instead
TypeRepr.of[B].typeSymbol.primaryConstructor.paramSymss
)
'{ new ParserForClass[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) }
val erasedMainData = '{$mainData.asInstanceOf[MainData[B, Any]]}
'{ new ParserForClass[B]($erasedMainData, () => ${ Ident(companionModule).asExpr }) }
}

def createMainData[T: Type, B: Type](using Quotes)
Expand All @@ -57,41 +68,84 @@ object Macros {
createMainData[T, B](method, mainAnnotation, method.paramSymss)
}

private object VarargTypeRepr {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[quotes.reflect.TypeRepr] = {
import quotes.reflect.*
tpe match {
case AnnotatedType(AppliedType(_, Seq(arg)), x)
if x.tpe =:= defn.RepeatedAnnot.typeRef => Some(arg)
case _ => None
}
}
}

private object AsType {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Some[Type[?]] = {
Some(tpe.asType)
}
}

def createMainData[T: Type, B: Type](using Quotes)
(method: quotes.reflect.Symbol,
mainAnnotation: quotes.reflect.Term,
annotatedParamsLists: List[List[quotes.reflect.Symbol]]): Expr[MainData[T, B]] = {

import quotes.reflect.*
val params = method.paramSymss.headOption.getOrElse(report.throwError("Multiple parameter lists not supported"))
val defaultParams = getDefaultParams(method)
val defaultParams = if (params.exists(_.flags.is(Flags.HasDefault))) getDefaultParams(method) else Map.empty
val argSigsExprs = params.zip(annotatedParamsLists.flatten).map { paramAndAnnotParam =>
val param = paramAndAnnotParam._1
val annotParam = paramAndAnnotParam._2
val paramTree = param.tree.asInstanceOf[ValDef]
val paramTpe = paramTree.tpt.tpe
val readerTpe = paramTpe match {
case VarargTypeRepr(AsType('[t])) => TypeRepr.of[Leftover[t]]
case _ => paramTpe
}
val arg = annotParam.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
val paramType = paramTpe.asType
paramType match
readerTpe.asType match {
case '[t] =>
def applyAndCast(f: Expr[Any] => Expr[Any], arg: Expr[B]): Expr[t] = {
f(arg) match {
case '{ $v: `t` } => v
case expr => {
// this case will be activated when the found default parameter is not of type `t`
val recoveredType =
try
expr.asExprOf[t]
catch
case err: Exception =>
report.errorAndAbort(
s"""Failed to convert default value for parameter ${param.name},
|expected type: ${paramTpe.show},
|but default value ${expr.show} is of type: ${expr.asTerm.tpe.widen.show}
|while converting type caught an exception with message: ${err.getMessage}
|There might be a bug in mainargs.""".stripMargin,
param.pos.getOrElse(Position.ofMacroExpansion)
)
recoveredType
}
}
}
val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match {
case Some('{ $v: `t`}) => '{ Some(((_: B) => $v)) }
case Some(f) => '{ Some((b: B) => ${ applyAndCast(f, 'b) }) }
case None => '{ None }
}
val tokensReader = Expr.summon[mainargs.TokensReader[t]].getOrElse {
report.throwError(
s"No mainargs.ArgReader found for parameter ${param.name}",
param.pos.get
report.errorAndAbort(
s"No mainargs.TokensReader[${Type.show[t]}] found for parameter ${param.name} of method ${method.name} in ${method.owner.fullName}",
method.pos.getOrElse(Position.ofMacroExpansion)
)
}
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ tokensReader })) }
}
}
val argSigs = Expr.ofList(argSigsExprs)

val invokeRaw: Expr[(B, Seq[Any]) => T] = {

def callOf(methodOwner: Expr[Any], args: Expr[Seq[Any]]) =
call(methodOwner, method, '{ Seq($args) }).asExprOf[T]
call(methodOwner, method, args).asExprOf[T]

'{ (b: B, params: Seq[Any]) => ${ callOf('b, 'params) } }
}
Expand Down Expand Up @@ -120,37 +174,50 @@ object Macros {
private def call(using Quotes)(
methodOwner: Expr[Any],
method: quotes.reflect.Symbol,
argss: Expr[Seq[Seq[Any]]]
args: Expr[Seq[Any]]
): Expr[_] = {
// Copy pasted from Cask.
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
import quotes.reflect._
val paramss = method.paramSymss

if (paramss.isEmpty) {
report.throwError("At least one parameter list must be declared.", method.pos.get)
report.errorAndAbort("At least one parameter list must be declared.", method.pos.get)
}

val accesses: List[List[Term]] = for (i <- paramss.indices.toList) yield {
for (j <- paramss(i).indices.toList) yield {
val tpe = paramss(i)(j).tree.asInstanceOf[ValDef].tpt.tpe
tpe.asType match
case '[t] => '{ $argss(${Expr(i)})(${Expr(j)}).asInstanceOf[t] }.asTerm
}
if (paramss.sizeIs > 1) {
report.errorAndAbort("Multiple parameter lists are not supported.", method.pos.get)
}
val params = paramss.head

val methodType = methodOwner.asTerm.tpe.memberType(method)

def accesses(ref: Expr[Seq[Any]]): List[Term] =
for (i <- params.indices.toList) yield {
val param = params(i)
val tpe = methodType.memberType(param)
val untypedRef = '{ $ref(${Expr(i)}) }
tpe match {
case VarargTypeRepr(AsType('[t])) =>
Typed(
'{ $untypedRef.asInstanceOf[Leftover[t]].value }.asTerm,
Inferred(AppliedType(defn.RepeatedParamClass.typeRef, List(TypeRepr.of[t])))
)
case _ => tpe.asType match
case '[t] => '{ $untypedRef.asInstanceOf[t] }.asTerm
}
}

methodOwner.asTerm.select(method).appliedToArgss(accesses).asExpr
methodOwner.asTerm.select(method).appliedToArgs(accesses(args)).asExpr
}


/** Lookup default values for a method's parameters. */
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any]] = {
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any] => Expr[Any]] = {
// Copy pasted from Cask.
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L38
import quotes.reflect._

val params = method.paramSymss.flatten
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any]]
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any] => Expr[Any]]

val Name = (method.name + """\$default\$(\d+)""").r
val InitName = """\$lessinit\$greater\$default\$(\d+)""".r
Expand All @@ -159,13 +226,13 @@ object Macros {

idents.foreach{
case deff @ DefDef(Name(idx), _, _, _) =>
val expr = Ref(deff.symbol).asExpr
val expr = (owner: Expr[Any]) => Select(owner.asTerm, deff.symbol).asExpr
defaults += (params(idx.toInt - 1) -> expr)

// The `apply` method re-uses the default param factory methods from `<init>`,
// so make sure to check if those exist too
case deff @ DefDef(InitName(idx), _, _, _) if method.name == "apply" =>
val expr = Ref(deff.symbol).asExpr
val expr = (owner: Expr[Any]) => Select(owner.asTerm, deff.symbol).asExpr
defaults += (params(idx.toInt - 1) -> expr)

case _ =>
Expand Down
69 changes: 69 additions & 0 deletions mainargs/test/src/ClassTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,56 @@ object ClassTests extends TestSuite {
@main
case class Qux(moo: String, b: Bar)

case class Cli(@arg(short = 'd') debug: Flag)

@main
class Compat(
@arg(short = 'h') val home: String,
@arg(short = 's') val silent: Flag,
val leftoverArgs: Leftover[String]
) {
override def equals(obj: Any): Boolean =
obj match {
case c: Compat =>
home == c.home && silent == c.silent && leftoverArgs == c.leftoverArgs
case _ => false
}
}
object Compat {
def apply(
home: String = "/home",
silent: Flag = Flag(),
leftoverArgs: Leftover[String] = Leftover()
) = new Compat(home, silent, leftoverArgs)

@deprecated("bin-compat shim", "0.1.0")
private[mainargs] def apply(
home: String,
silent: Flag,
noDefaultPredef: Flag,
leftoverArgs: Leftover[String]
) = new Compat(home, silent, leftoverArgs)
}

implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
implicit val barParser: ParserForClass[Bar] = ParserForClass[Bar]
implicit val quxParser: ParserForClass[Qux] = ParserForClass[Qux]
implicit val cliParser: ParserForClass[Cli] = ParserForClass[Cli]
implicit val compatParser: ParserForClass[Compat] = ParserForClass[Compat]

class PathWrap {
@main
case class Foo(x: Int = 23, y: Int = 47)

object Main {
@main
def run(bar: Bar, bool: Boolean = false) = {
s"${bar.w.value} ${bar.f.x} ${bar.f.y} ${bar.zzzz} $bool"
}
}

implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
}

object Main {
@main
Expand Down Expand Up @@ -161,5 +208,27 @@ object ClassTests extends TestSuite {
Seq("-x", "1", "-y", "2", "-z", "hello")
) ==> "false 1 2 hello false"
}
test("mill-compat") {
test("apply-overload-class") {
compatParser.constructOrThrow(Seq("foo")) ==> Compat(
home = "/home",
silent = Flag(false),
leftoverArgs = Leftover("foo")
)
}
test("no-main-on-class") {
cliParser.constructOrThrow(Seq("-d")) ==> Cli(Flag(true))
}
test("path-dependent-default") {
val p = new PathWrap
p.fooParser.constructOrThrow(Seq()) ==> p.Foo(23, 47)
}
test("path-dependent-default-method") {
val p = new PathWrap
ParserForMethods(p.Main).runOrThrow(
Seq("-x", "1", "-y", "2", "-z", "hello")
) ==> "false 1 2 hello false"
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ object VarargsOldTests extends VarargsBaseTests {

@main
def mixedVariadic(@arg(short = 'f') first: Int, args: String*) =
first + args.mkString
first.toString + args.mkString
}

val check = new Checker(ParserForMethods(Base), allowPositional = true)
Expand Down

0 comments on commit 37d64b8

Please sign in to comment.