From 3fe797ceb47d4702a59d6077528805477edf8a27 Mon Sep 17 00:00:00 2001 From: Albert Meltzer <7529386+kitbellew@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:03:56 -0800 Subject: [PATCH] AvoidInfix: process scala3 `match` operator, too --- .../org/scalafmt/rewrite/AvoidInfix.scala | 33 +++++++++++++++---- .../test/resources/scala3/OptionalBraces.stat | 12 ++++--- .../resources/scala3/OptionalBraces_fold.stat | 4 +-- .../resources/scala3/OptionalBraces_keep.stat | 6 ++-- .../scala3/OptionalBraces_unfold.stat | 12 +++---- .../test/scala/org/scalafmt/FormatTests.scala | 2 +- 6 files changed, 46 insertions(+), 23 deletions(-) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala index 49987843c..b74efb997 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala @@ -21,6 +21,7 @@ object AvoidInfix extends RewriteFactory { class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { private val cfg = ctx.style.rewrite.avoidInfix + private val allowMatchAsOperator = dialect.allowMatchAsOperator // In a perfect world, we could just use // Tree.transform { @@ -30,18 +31,26 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { // we will do these dangerous rewritings by hand. override def rewrite(tree: Tree): Unit = tree match { - case x: Term.ApplyInfix => rewriteImpl(x.lhs, x.op, x.arg, x.targClause) + case x: Term.ApplyInfix => + rewriteImpl(x.lhs, Right(x.op), x.arg, x.targClause) case x: Term.Select if !cfg.excludePostfix && noDot(x.name.tokens.head) => - rewriteImpl(x.qual, x.name) + rewriteImpl(x.qual, Right(x.name)) + case x: Term.Match => noDotMatch(x) + .foreach(op => rewriteImpl(x.expr, Left(op), null)) case _ => } private def noDot(opToken: T): Boolean = !ctx.tokenTraverser.prevNonTrivialToken(opToken).forall(_.is[T.Dot]) + private def noDotMatch(t: Term.Match): Option[T] = + if (allowMatchAsOperator && !cfg.excludeMatch) ctx.tokenTraverser + .prevNonTrivialToken(t.casesBlock.tokens.head).filter(noDot) + else None + private def rewriteImpl( lhs: Term, - op: Name, + op: Either[T, Name], rhs: Tree = null, targs: Member.SyntaxValuesClause = null, ): Unit = { @@ -52,19 +61,22 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { val lhsIsOK = lhsIsWrapped || (lhs match { case y: Term.ApplyInfix => checkMatchingInfix(y.lhs, y.op.value, y.arg) + case y: Term.Match => noDotMatch(y) + .exists(kw => checkMatchingInfix(y.expr, kw.text, y.casesBlock)) case _ => false }) - if (!checkMatchingInfix(lhs, op.value, rhs, Some(lhsIsOK))) return + if (!checkMatchingInfix(lhs, op.fold(_.text, _.value), rhs, Some(lhsIsOK))) + return if (!ctx.dialect.allowTryWithAnyExpr) if (beforeLhsHead.exists(_.is[T.KwTry])) return val builder = Seq.newBuilder[TokenPatch] - val (opHead, opLast) = ends(op) + val (opHead, opLast) = op.fold((_, null), ends) builder += TokenPatch.AddLeft(opHead, ".", keepTok = true) - if (rhs ne null) { + if ((rhs ne null) && (opLast ne null)) { def moveOpenDelim(prev: T, open: T): Unit = { // move delimiter (before comment or newline) builder += TokenPatch.AddRight(prev, open.text, keepTok = true) @@ -124,6 +136,15 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { case None => isWrapped(lhs) || checkMatchingInfix(lhs.lhs, lhs.op.value, lhs.arg) } + case lhs: Term.Match if hasPlaceholder(lhs, includeArg = true) => + lhsIsOK match { + case Some(x) => x + case None if isWrapped(lhs) => true + case None => noDotMatch(lhs) match { + case Some(op) => checkMatchingInfix(lhs.expr, op.text, null) + case _ => true + } + } case _ => true }) diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat index 1a6d07170..13f0e4326 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat @@ -7587,8 +7587,9 @@ object a { } >>> object a { - a.b(c) match - case _ => + a.b(c) + .match + case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch rewrite.rules = [AvoidInfix] @@ -7616,9 +7617,10 @@ object a { } >>> object a { - (a.b(c) match { - case _ => - }).d(e) + a.b(c) + .match { case _ => + } + .d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat index e59c7ea44..95e2a0058 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat @@ -7298,7 +7298,7 @@ object a { } >>> object a { - a.b(c) match + a.b(c).match case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch @@ -7325,7 +7325,7 @@ object a { } >>> object a { - (a.b(c) match { case _ => }).d(e) + a.b(c).match { case _ => }.d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat index 259518736..e97571659 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat @@ -7615,7 +7615,7 @@ object a { } >>> object a { - a.b(c) match + a.b(c).match case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch @@ -7644,9 +7644,9 @@ object a { } >>> object a { - (a.b(c) match { + a.b(c).match { case _ => - }).d(e) + }.d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat index f0921787b..24a2ad72a 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat @@ -7902,8 +7902,9 @@ object a { } >>> object a { - a.b(c) match - case _ => + a.b(c) + .match + case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch rewrite.rules = [AvoidInfix] @@ -7933,11 +7934,10 @@ object a { } >>> object a { - ( - a.b(c) match { - case _ => + a.b(c) + .match { case _ => } - ).d(e) + .d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] diff --git a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala index c97355882..ef8ab54c3 100644 --- a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala +++ b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions { val explored = Debug.explored.get() logger.debug(s"Total explored: $explored") if (!onlyUnit && !onlyManual) - assertEquals(explored, 1083996, "total explored") + assertEquals(explored, 1084044, "total explored") val results = debugResults.result() // TODO(olafur) don't block printing out test results. // I don't want to deal with scalaz's Tasks :'(