diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index 2ed180d26b..2adf1e8077 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -85,20 +85,31 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { ): Replacement = { val rt = ft.right val rtOwner = ft.meta.rightOwner - val ok = okToReplaceFunctionInSingleArgApply(rtOwner).exists { - case (lp, func) => lp.eq(rt) && func.tokens.last.is[Token.RightBrace] + def lpFunction = okToReplaceFunctionInSingleArgApply(rtOwner).map { + case (lp, f) if lp.eq(rt) && f.tokens.last.is[Token.RightBrace] => + replaceToken("{", Some(rtOwner)) { + new Token.LeftBrace(rt.input, rt.dialect, rt.start) + } + case _ => null } - if (ok) replaceToken("{", Some(rtOwner)) { - new Token.LeftBrace(rt.input, rt.dialect, rt.start) + // single-arg apply of a partial function + // a({ case b => c; d }) change to a { case b => c; d } + def lpPartialFunction = rtOwner match { + case ta @ Term.Apply(_, List(arg)) => + getOpeningParen(ta).map { lp => + val ko = lp.ne(rt) || getBlockNestedPartialFunction(arg).isEmpty + if (ko) null else removeToken + } + case _ => None } - else null + + lpFunction.orElse(lpPartialFunction).orNull } private def onRightParen( left: Replacement )(implicit ft: FormatToken): (Replacement, Replacement) = - if (left.exists(_.right.is[Token.LeftBrace])) (left, removeToken) - else null + (left, removeToken) private def onLeftBrace(implicit ft: FormatToken, @@ -114,7 +125,8 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { else if (okToReplaceFunctionInSingleArgApply(t)) replace else removeToken case t: Term.Block => - if (okToReplaceBlockInSingleArgApply(t)) replace + if (getBlockNestedPartialFunction(t).isDefined) removeToken + else if (okToReplaceBlockInSingleArgApply(t)) replace else if (processBlock(t)) removeToken else null case _: Term.Interpolate => @@ -410,6 +422,23 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { case _ => okIfMultipleStats }) + private def getBlockNestedPartialFunction( + tree: Tree + ): Option[Term.PartialFunction] = tree match { + case x: Term.PartialFunction => Some(x) + case x: Term.Block => getBlockNestedPartialFunction(x) + case _ => None + } + + @tailrec + private def getBlockNestedPartialFunction( + tree: Term.Block + ): Option[Term.PartialFunction] = getBlockSingleStat(tree) match { + case Some(x: Term.PartialFunction) => Some(x) + case Some(x: Term.Block) => getBlockNestedPartialFunction(x) + case _ => None + } + private def getSingleStatIfLineSpanOk(b: Term.Block)(implicit style: ScalafmtConfig ): Option[Stat] = diff --git a/scalafmt-tests/src/test/resources/rewrite/RedundantBraces.stat b/scalafmt-tests/src/test/resources/rewrite/RedundantBraces.stat index d1b070e0b8..ff2176c52b 100644 --- a/scalafmt-tests/src/test/resources/rewrite/RedundantBraces.stat +++ b/scalafmt-tests/src/test/resources/rewrite/RedundantBraces.stat @@ -1082,3 +1082,63 @@ object a { y } } +<<< single-block partial function with parens one-arg apply +object a { + val foo = bar ( + { case x => y } + ) + val foo = bar ( + { { case x => y } } + ) + val foo = bar ( + { { { case x => y } } } + ) + val foo = bar.baz ( + { case x => y } + ) + val foo = bar.baz ( + { { case x => y } } + ) + val foo = bar.baz ( + { { { case x => y } } } + ) +} +>>> +object a { + val foo = bar { case x => y } + val foo = bar { case x => y } + val foo = bar { case x => y } + val foo = bar.baz { case x => y } + val foo = bar.baz { case x => y } + val foo = bar.baz { case x => y } +} +<<< single-block partial function with block one-arg apply +object a { + val foo = bar { + { case x => y } + } + val foo = bar { + { { case x => y } } + } + val foo = bar { + { { { case x => y } } } + } + val foo = bar.baz { + { case x => y } + } + val foo = bar.baz { + { { case x => y } } + } + val foo = bar.baz { + { { { case x => y } } } + } +} +>>> +object a { + val foo = bar { case x => y } + val foo = bar { case x => y } + val foo = bar { case x => y } + val foo = bar.baz { case x => y } + val foo = bar.baz { case x => y } + val foo = bar.baz { case x => y } +}