diff --git a/cli/src/main/scala/scalafix/cli/ArgParserImplicits.scala b/cli/src/main/scala/scalafix/cli/ArgParserImplicits.scala index 7c053249b..37867b12b 100644 --- a/cli/src/main/scala/scalafix/cli/ArgParserImplicits.scala +++ b/cli/src/main/scala/scalafix/cli/ArgParserImplicits.scala @@ -1,7 +1,7 @@ package scalafix.cli -import scalafix.rewrite.ProcedureSyntax import scalafix.rewrite.Rewrite +import scalafix.rewrite.VolatileLazyVal import java.io.InputStream import java.io.PrintStream @@ -9,20 +9,14 @@ import java.io.PrintStream import caseapp.core.ArgParser object ArgParserImplicits { - def nameMap[T](t: sourcecode.Text[T]*): Map[String, T] = { - t.map(x => x.source -> x.value).toMap - } - val rewriteMap: Map[String, Rewrite] = nameMap( - ProcedureSyntax - ) implicit val rewriteRead: ArgParser[Rewrite] = ArgParser.instance[Rewrite] { str => - rewriteMap.get(str) match { + Rewrite.name2rewrite.get(str) match { case Some(x) => Right(x) case _ => - Left( - s"invalid input $str, must be one of ${rewriteMap.keys.mkString(", ")}") + val availableKeys = Rewrite.name2rewrite.keys.mkString(", ") + Left(s"invalid input $str, must be one of $availableKeys") } } diff --git a/cli/src/main/scala/scalafix/cli/Cli.scala b/cli/src/main/scala/scalafix/cli/Cli.scala index 106fd43b1..1ac1a0e83 100644 --- a/cli/src/main/scala/scalafix/cli/Cli.scala +++ b/cli/src/main/scala/scalafix/cli/Cli.scala @@ -1,7 +1,7 @@ package scalafix.cli import scala.collection.GenSeq -import scalafix.FixResult +import scalafix.Fixed import scalafix.Scalafix import scalafix.cli.ArgParserImplicits._ import scalafix.rewrite.Rewrite @@ -30,7 +30,7 @@ case class CommonOptions( case class ScalafixOptions( @HelpMessage( s"Rules to run, one of: ${Rewrite.default.mkString(", ")}" - ) rewrites: List[Rewrite] = Rewrite.default, + ) rewrites: List[Rewrite] = Rewrite.default.toList, @Hidden @HelpMessage( "Files to fix. Runs on all *.scala files if given a directory." ) @ExtraName("f") files: List[String] = List.empty[String], @@ -56,13 +56,13 @@ object Cli extends AppOf[ScalafixOptions] { def handleFile(file: File, config: ScalafixOptions): Unit = { Scalafix.fix(FileOps.readFile(file), config.rewrites) match { - case FixResult.Success(code) => + case Fixed.Success(code) => if (config.inPlace) { FileOps.writeFile(file, code) } else config.common.out.write(code.getBytes) - case FixResult.Failure(e) => + case Fixed.Failure(e) => config.common.err.write(s"Failed to fix $file. Cause: $e".getBytes) - case e: FixResult.ParseError => + case e: Fixed.ParseError => if (config.files.contains(file)) { // Only log if user explicitly specified that file. config.common.err.write(e.toString.getBytes()) diff --git a/cli/src/main/scala/scalafix/cli/Scalafix210.scala b/cli/src/main/scala/scalafix/cli/Scalafix210.scala index 8b5e5843b..7f5c51e06 100644 --- a/cli/src/main/scala/scalafix/cli/Scalafix210.scala +++ b/cli/src/main/scala/scalafix/cli/Scalafix210.scala @@ -1,6 +1,6 @@ package scalafix.cli -import scalafix.FixResult +import scalafix.Fixed import scalafix.Scalafix import scalafix.rewrite.Rewrite import scalafix.util.logger @@ -8,8 +8,8 @@ import scalafix.util.logger class Scalafix210 { def fix(originalContents: String, filename: String): String = { Scalafix.fix(originalContents, Rewrite.default) match { - case FixResult.Success(fixedCode) => fixedCode - case FixResult.Error(e) => + case Fixed.Success(fixedCode) => fixedCode + case Fixed.Failure(e) => logger.warn(s"Failed to fix $filename. Cause ${e.getMessage}") originalContents } diff --git a/core/src/main/scala/scalafix/FixResult.scala b/core/src/main/scala/scalafix/FixResult.scala deleted file mode 100644 index 0ef3cf20e..000000000 --- a/core/src/main/scala/scalafix/FixResult.scala +++ /dev/null @@ -1,25 +0,0 @@ -package scalafix - -import scala.meta.inputs.Position - -abstract sealed class FixResult { - def get: String = this match { - case FixResult.Success(code) => code - case FixResult.Failure(e) => throw e - case e: FixResult.ParseError => throw e.exception - } -} - -object FixResult { - case class Success(code: String) extends FixResult - object Error { - def unapply(fixResult: FixResult): Option[Throwable] = fixResult match { - case Failure(e) => Some(e) - case e: ParseError => Some(e.exception) - case _ => None - } - } - case class Failure(e: Throwable) extends FixResult - case class ParseError(pos: Position, message: String, exception: Throwable) - extends FixResult -} diff --git a/core/src/main/scala/scalafix/Fixed.scala b/core/src/main/scala/scalafix/Fixed.scala new file mode 100644 index 000000000..8152e0777 --- /dev/null +++ b/core/src/main/scala/scalafix/Fixed.scala @@ -0,0 +1,24 @@ +package scalafix + +import scala.meta.inputs.Position + +abstract sealed class Fixed { + def get: String = this match { + case Fixed.Success(code) => code + case Fixed.Failure(e) => throw e + } +} + +object Fixed { + case class Success(code: String) extends Fixed + class Failure(val e: Throwable) extends Fixed + object Failure { + def apply(exception: Throwable): Failure = new Failure(exception) + def unapply(arg: Failure): Option[Throwable] = { + Some(arg.e) + } + } + + case class ParseError(pos: Position, message: String, exception: Throwable) + extends Failure(exception) +} diff --git a/core/src/main/scala/scalafix/Scalafix.scala b/core/src/main/scala/scalafix/Scalafix.scala index fbf0c3903..e0f69c07e 100644 --- a/core/src/main/scala/scalafix/Scalafix.scala +++ b/core/src/main/scala/scalafix/Scalafix.scala @@ -5,17 +5,17 @@ import scala.util.control.NonFatal import scalafix.rewrite.Rewrite object Scalafix { - def fix(code: String, rewriters: Seq[Rewrite] = Rewrite.default): FixResult = { + def fix(code: String, rewriters: Seq[Rewrite] = Rewrite.default): Fixed = { fix(Input.String(code), rewriters) } - def fix(code: Input, rewriters: Seq[Rewrite]): FixResult = { - rewriters.foldLeft[FixResult]( - FixResult.Success(String.copyValueOf(code.chars))) { - case (newCode: FixResult.Success, rewriter) => + def fix(code: Input, rewriters: Seq[Rewrite]): Fixed = { + rewriters.foldLeft[Fixed]( + Fixed.Success(String.copyValueOf(code.chars))) { + case (newCode: Fixed.Success, rewriter) => try rewriter.rewrite(Input.String(newCode.code)) catch { - case NonFatal(e) => FixResult.Failure(e) + case NonFatal(e) => Fixed.Failure(e) } case (failure, _) => failure } diff --git a/core/src/main/scala/scalafix/rewrite/ProcedureSyntax.scala b/core/src/main/scala/scalafix/rewrite/ProcedureSyntax.scala index e066afa29..44ce29473 100644 --- a/core/src/main/scala/scalafix/rewrite/ProcedureSyntax.scala +++ b/core/src/main/scala/scalafix/rewrite/ProcedureSyntax.scala @@ -1,10 +1,10 @@ package scalafix.rewrite import scala.meta._ -import scalafix.FixResult +import scalafix.Fixed object ProcedureSyntax extends Rewrite { - override def rewrite(code: Input): FixResult = { + override def rewrite(code: Input): Fixed = { withParsed(code) { ast => val toPrepend = ast.collect { case t: Defn.Def if t.decltpe.exists(_.tokens.isEmpty) => @@ -21,7 +21,7 @@ object ProcedureSyntax extends Rewrite { sb.append(token.syntax) } val result = sb.toString() - FixResult.Success(result) + Fixed.Success(result) } } } diff --git a/core/src/main/scala/scalafix/rewrite/Rewrite.scala b/core/src/main/scala/scalafix/rewrite/Rewrite.scala index 274e79eb1..474aebe4c 100644 --- a/core/src/main/scala/scalafix/rewrite/Rewrite.scala +++ b/core/src/main/scala/scalafix/rewrite/Rewrite.scala @@ -1,23 +1,30 @@ package scalafix.rewrite import scala.meta._ -import scalafix.FixResult +import scalafix.Fixed abstract class Rewrite { - def rewrite(code: Input): FixResult + def rewrite(code: Input): Fixed - protected def withParsed(code: Input)(f: Tree => FixResult): FixResult = { + protected def withParsed(code: Input)(f: Tree => Fixed): Fixed = { code.parse[Source] match { case Parsed.Success(ast) => f(ast) case Parsed.Error(pos, msg, details) => - FixResult.ParseError(pos, msg, details) + Fixed.ParseError(pos, msg, details) } } } object Rewrite { - val default: List[Rewrite] = List( - ProcedureSyntax + private def nameMap[T](t: sourcecode.Text[T]*): Map[String, T] = { + t.map(x => x.source -> x.value).toMap + } + + val name2rewrite: Map[String, Rewrite] = nameMap[Rewrite]( + ProcedureSyntax, + VolatileLazyVal ) + + val default: Seq[Rewrite] = name2rewrite.values.toSeq } diff --git a/core/src/main/scala/scalafix/rewrite/VolatileLazyVal.scala b/core/src/main/scala/scalafix/rewrite/VolatileLazyVal.scala new file mode 100644 index 000000000..d33514d67 --- /dev/null +++ b/core/src/main/scala/scalafix/rewrite/VolatileLazyVal.scala @@ -0,0 +1,34 @@ +package scalafix.rewrite + +import scala.meta._ +import scalafix.Fixed +import scalafix.util.logger + +object VolatileLazyVal extends Rewrite { + private object NonVolatileLazyVal { + def unapply(defn: Defn.Val): Option[Token] = { + defn.mods.collectFirst { + case x if x.syntax == "@volatile" => + None + case x if x.syntax == "lazy" => + Some(x.tokens.head) + } + }.flatten + } + override def rewrite(code: Input): Fixed = { + withParsed(code) { ast => + val toPrepend: Seq[Token] = ast.collect { + case NonVolatileLazyVal(tok) => tok + } + val sb = new StringBuilder + ast.tokens.foreach { token => + if (toPrepend.contains(token)) { + sb.append("@volatile ") + } + sb.append(token.syntax) + } + val result = sb.toString() + Fixed.Success(result) + } + } +} diff --git a/core/src/test/scala/scalafix/ScalafixSuite.scala b/core/src/test/scala/scalafix/ScalafixSuite.scala index c0bf35b1f..603f04b06 100644 --- a/core/src/test/scala/scalafix/ScalafixSuite.scala +++ b/core/src/test/scala/scalafix/ScalafixSuite.scala @@ -51,7 +51,7 @@ class ScalafixSuite extends FunSuite with DiffAssertions { test("on parse error") { val obtained = Scalafix.fix("object A {") - assert(obtained.isInstanceOf[FixResult.ParseError]) + assert(obtained.isInstanceOf[Fixed.ParseError]) } } diff --git a/core/src/test/scala/scalafix/rewrite/LazyValSuite.scala b/core/src/test/scala/scalafix/rewrite/LazyValSuite.scala new file mode 100644 index 000000000..861732941 --- /dev/null +++ b/core/src/test/scala/scalafix/rewrite/LazyValSuite.scala @@ -0,0 +1,57 @@ +package scalafix.rewrite + +import scala.meta.inputs.Input +import scalafix.Fixed +import scalafix.util.DiffAssertions + +import org.scalatest.FunSuiteLike + +class RewriteSuite(rewrite: Rewrite) extends FunSuiteLike with DiffAssertions { + + def rewriteTest(name: String, original: String, expected: String): Unit = { + test(name) { + rewrite.rewrite(Input.String(original)) match { + case Fixed.Success(obtained) => + assertNoDiff(obtained, expected) + case Fixed.Failure(e) => + throw e + } + } + } + +} + +class LazyValSuite extends RewriteSuite(VolatileLazyVal) { + + rewriteTest( + "basic", + """|object a { + | + |val foo = 1 + | + | lazy val x = 2 + | @volatile lazy val dontChangeMe = 2 + | + | class foo { + | lazy val z = { + | reallyHardStuff() + | } + | } + |} + """.stripMargin, + """|object a { + | + |val foo = 1 + | + | @volatile lazy val x = 2 + | @volatile lazy val dontChangeMe = 2 + | + | class foo { + | @volatile lazy val z = { + | reallyHardStuff() + | } + | } + |} + """.stripMargin + ) +}