Skip to content

Commit

Permalink
Simplify MainAnnotation interface
Browse files Browse the repository at this point in the history
Remove the `Command` class and place the `argGetter`, `varargsGetter`
and `run` methods directly in the `MainAnnotation` interface.

Now `command` pre-processes the arguments which clearly states which
strings will be used for each argument. This simplifies the implementation
of the `MainAnnotation` methods.
  • Loading branch information
nicolasstucki committed Apr 14, 2022
1 parent e852aa7 commit 813e059
Show file tree
Hide file tree
Showing 15 changed files with 511 additions and 462 deletions.
104 changes: 64 additions & 40 deletions compiler/src/dotty/tools/dotc/ast/MainProxies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,29 +143,31 @@ object MainProxies {
* */
* @myMain(80) def f(
* @myMain.Alias("myX") x: S,
* y: S,
* ys: T*
* ) = ...
*
* would be translated to something like
*
* final class f {
* static def main(args: Array[String]): Unit = {
* val cmd = new myMain(80).command(
* info = new CommandInfo(
* name = "f",
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
* parameters = Seq(
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
* new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq())
* )
* val annotation = new myMain(80)
* val info = new Info(
* name = "f",
* documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
* parameters = Seq(
* new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))),
* new scala.annotation.MainAnnotation.Parameter("y", "S", true, false, "", Seq()),
* new scala.annotation.MainAnnotation.Parameter("ys", "T", false, true, "all my params y", Seq())
* )
* args = args
* )
*
* val args0: () => S = cmd.argGetter[S](0, None)
* val args1: () => Seq[T] = cmd.varargGetter[T]
*
* cmd.run(() => f(args0(), args1()*))
* ),
* val command = annotation.command(info, args)
* if command.isDefined then
* val cmd = command.get
* val args0: () => S = annotation.argGetter[S](info.parameters(0), cmd(0), None)
* val args1: () => S = annotation.argGetter[S](info.parameters(1), mainArgs(1), Some(() => sum$default$1()))
* val args2: () => Seq[T] = annotation.varargGetter[T](info.parameters(2), cmd.drop(2))
* annotation.run(() => f(args0(), args1(), args2()*))
* }
* }
*/
Expand Down Expand Up @@ -229,7 +231,7 @@ object MainProxies {
*
* A ParamInfo has the following shape
* ```
* new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
* new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
* ```
*/
def parameterInfos(mt: MethodType): List[Tree] =
Expand All @@ -252,33 +254,34 @@ object MainProxies {
val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc)
.map(value => Literal(Constant(value)))

New(TypeTree(defn.MainAnnotationParameterInfo.typeRef), List(constructorArgs :+ paramAnnots))
New(TypeTree(defn.MainAnnotationParameter.typeRef), List(constructorArgs :+ paramAnnots))

end parameterInfos

/**
* Creates a list of references and definitions of arguments.
* The goal is to create the
* `val args0: () => S = cmd.argGetter[S](0, None)`
* `val args0: () => S = annotation.argGetter[S](0, cmd(0), None)`
* part of the code.
*/
def argValDefs(mt: MethodType): List[ValDef] =
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
val argName = nme.args ++ idx.toString
val isRepeated = formal.isRepeatedParam
val formalType = if isRepeated then formal.argTypes.head else formal
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
case None => ref(defn.NoneModule.termRef)
case Some(dvSym) =>
val value = unitToValue(ref(dvSym.termRef))
Apply(ref(defn.SomeClass.companionModule.termRef), value)
val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterName), TypeTree(formalType) :: Nil)
val argGetter =
if isRepeated then argGetter0
else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt))

ValDef(argName, TypeTree(), argGetter)
val argName = nme.args ++ idx.toString
val isRepeated = formal.isRepeatedParam
val formalType = if isRepeated then formal.argTypes.head else formal
val getterName = if isRepeated then nme.varargGetter else nme.argGetter
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
case None => ref(defn.NoneModule.termRef)
case Some(dvSym) =>
val value = unitToValue(ref(dvSym.termRef))
Apply(ref(defn.SomeClass.companionModule.termRef), value)
val argGetter0 = TypeApply(Select(Ident(nme.annotation), getterName), TypeTree(formalType) :: Nil)
val index = Literal(Constant(idx))
val paramInfo = Apply(Select(Ident(nme.info), nme.parameters), index)
val argGetter =
if isRepeated then Apply(argGetter0, List(paramInfo, Apply(Select(Ident(nme.cmd), nme.drop), List(index))))
else Apply(argGetter0, List(paramInfo, Apply(Ident(nme.cmd), List(index)), defaultValueGetterOpt))
ValDef(argName, TypeTree(), argGetter)
end argValDefs


Expand Down Expand Up @@ -318,18 +321,39 @@ object MainProxies {
val nameTree = Literal(Constant(mainFun.showName))
val docTree = Literal(Constant(documentation.mainDoc))
val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos)
New(TypeTree(defn.MainAnnotationCommandInfo.typeRef), List(List(nameTree, docTree, paramInfos)))
New(TypeTree(defn.MainAnnotationInfo.typeRef), List(List(nameTree, docTree, paramInfos)))

val cmd = ValDef(
nme.cmd,
val annotVal = ValDef(
nme.annotation,
TypeTree(),
instantiateAnnotation(mainAnnot)
)
val infoVal = ValDef(
nme.info,
TypeTree(),
cmdInfo
)
val command = ValDef(
nme.command,
TypeTree(),
Apply(
Select(instantiateAnnotation(mainAnnot), nme.command),
List(cmdInfo, Ident(nme.args))
Select(Ident(nme.annotation), nme.command),
List(Ident(nme.info), Ident(nme.args))
)
)
val run = Apply(Select(Ident(nme.cmd), nme.run), mainCall)
val body = Block(cmdInfo :: cmd :: args, run)
val argsVal = ValDef(
nme.cmd,
TypeTree(),
Select(Ident(nme.command), nme.get)
)
val run = Apply(Select(Ident(nme.annotation), nme.run), mainCall)
val body0 = If(
Select(Ident(nme.command), nme.isDefined),
Block(argsVal :: args, run),
EmptyTree
)
val body = Block(List(annotVal, infoVal, command), body0) // TODO add `if (cmd.nonEmpty)`

val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree)
.withFlags(Param)
/** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,8 @@ class Definitions {
@tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope")

@tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation")
@tu lazy val MainAnnotationCommandInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.CommandInfo")
@tu lazy val MainAnnotationParameterInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfo")
@tu lazy val MainAnnotationInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Info")
@tu lazy val MainAnnotationParameter: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Parameter")
@tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation")
@tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command")

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ object StdNames {
val ordinalDollar: N = "$ordinal"
val ordinalDollar_ : N = "_$ordinal"
val origin: N = "origin"
val parameters: N = "parameters"
val parts: N = "parts"
val postfixOps: N = "postfixOps"
val prefix : N = "prefix"
Expand Down
92 changes: 49 additions & 43 deletions docs/_docs/reference/experimental/main-annotation.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,35 @@ When a users annotates a method with an annotation that extends `MainAnnotation`
* @param first Fist number to sum
* @param rest The rest of the numbers to sum
*/
@myMain def sum(first: Int, rest: Int*): Int = first + rest.sum
@myMain def sum(first: Int, second: Int = 0, rest: Int*): Int = first + second + rest.sum
```

```scala
object foo {
def main(args: Array[String]): Unit = {

val cmd = new myMain().command(
info = new CommandInfo(
name = "sum",
documentation = "Sum all the numbers",
parameters = Seq(
new ParameterInfo("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()),
new ParameterInfo("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq())
)
),
args = args
val mainAnnot = new myMain()
val info = new Info(
name = "foo.main",
documentation = "Sum all the numbers",
parameters = Seq(
new Parameter("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()),
new Parameter("second", "scala.Int", hasDefault=true, isVarargs=false, "", Seq()),
new Parameter("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq())
)
)
val args0 = cmd.argGetter[Int](0, None) // using a parser of Int
val args1 = cmd.varargGetter[Int] // using a parser of Int
cmd.run(() => sum(args0(), args1()*))
val mainArgsOpt = mainAnnot.command(info, args)
if mainArgsOpt.isDefined then
val mainArgs = mainArgsOpt.get
val args0 = mainAnnot.argGetter[Int](info.parameters(0), mainArgs(0), None) // using a parser of Int
val args1 = mainAnnot.argGetter[Int](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) // using a parser of Int
val args2 = mainAnnot.varargGetter[Int](info.parameters(2), mainArgs.drop(2)) // using a parser of Int
mainAnnot.run(() => sum(args0(), args1(), args2()*))
}
}
```

The implementation of the `main` method first instantiates the annotation and then creates a `Command`.
When creating the `Command`, the arguments can be checked and preprocessed.
The implementation of the `main` method first instantiates the annotation and then call `command`.
When calling the `command`, the arguments can be checked and preprocessed.
Then it defines a series of argument getters calling `argGetter` for each parameter and `varargGetter` for the last one if it is a varargs. `argGetter` gets an optional lambda that computes the default argument.
Finally, the `run` method is called to run the application. It receives a by-name argument that contains the call the annotated method with the instantiations arguments (using the lambdas from `argGetter`/`varargGetter`).

Expand All @@ -50,42 +52,46 @@ Example of implementation of `myMain` that takes all arguments positionally. It
// Parser used to parse command line arguments
import scala.util.CommandLineParser.FromString[T]

// Result type of the annotated method is Int
class myMain extends MainAnnotation:
import MainAnnotation.{ ParameterInfo, Command }
// Result type of the annotated method is Int and arguments are parsed using FromString
@experimental class myMain extends MainAnnotation[FromString, Int]:
import MainAnnotation.{ Info, Parameter }

/** A new command with arguments from `args` */
def command(info: CommandInfo, args: Array[String]): Command[FromString, Int] =
def command(info: Info, args: Seq[String]): Option[Seq[String]] =
if args.contains("--help") then
println(info.documentation)
// TODO: Print documentation of the parameters
System.exit(0)
assert(info.parameters.forall(!_.hasDefault), "Default arguments are not supported")
val (plainArgs, varargs) =
if info.parameters.last.isVarargs then
val numPlainArgs = info.parameters.length - 1
assert(numPlainArgs <= args.length, "Not enough arguments")
(args.take(numPlainArgs), args.drop(numPlainArgs))
None // do not parse or run the program
else if info.parameters.exists(_.hasDefault) then
println("Default arguments are not supported")
None
else if info.hasVarargs then
val numPlainArgs = info.parameters.length - 1
if numPlainArgs <= args.length then
println("Not enough arguments")
None
else
Some(args)
else
if info.parameters.length <= args.length then
println("Not enough arguments")
None
else if info.parameters.length >= args.length then
println("Too many arguments")
None
else
assert(info.parameters.length <= args.length, "Not enough arguments")
assert(info.parameters.length >= args.length, "Too many arguments")
(args, Array.empty[String])
new MyCommand(plainArgs, varargs)
Some(args)

@experimental
class MyCommand(plainArgs: Seq[String], varargs: Seq[String]) extends Command[FromString, Int]:
def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T =
() => parser.fromString(arg)

def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T =
() => parser.fromString(plainArgs(idx))
def varargGetter[T](param: Parameter, args: Seq[String])(using parser: FromString[T]): () => Seq[T] =
() => args.map(arg => parser.fromString(arg))

def varargGetter[T](using parser: FromString[T]): () => Seq[T] =
() => varargs.map(arg => parser.fromString(arg))
def run(program: () => Int): Unit =
println("executing program")

def run(program: () => Int): Unit =
println("executing program")
try {
val result = program()
println("result: " + result)
println("executed program")
end MyCommand
end myMain
```
Loading

0 comments on commit 813e059

Please sign in to comment.