diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala index 347b6c8c49..52a7efb2d7 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/FormatTokensRewrite.scala @@ -133,7 +133,7 @@ class FormatTokensRewrite( if (formatOff) None else session.claimedRule match { - case x @ Some(rule) => if (applyRule(rule)) x else None + case Some(c) => if (applyRule(c.rule)) Some(c.rule) else None case _ => applyRules } leftDelimIndex.prepend((ldelimIdx, ruleOpt)) @@ -147,7 +147,7 @@ class FormatTokensRewrite( val replacement = ruleOpt match { case Some(rule) if !ft.meta.formatOff && - session.claimedRule.forall(_ eq rule) => + session.claimedRule.forall(_.rule eq rule) => implicit val style = styleMap.at(ft.right) if (rule.enabled) rule.onRight(tokens(ldelimIdx), formatOff) else None @@ -245,20 +245,26 @@ object FormatTokensRewrite { private[rewrite] class Session(rules: Seq[Rule]) { private implicit val implicitSession: Session = this - private val claimed = new mutable.HashMap[Int, Rule]() + private val claimed = new mutable.HashMap[Int, Claimant]() private[FormatTokensRewrite] val tokens = new mutable.ArrayBuffer[Replacement]() - def claimedRule(implicit ft: FormatToken): Option[Rule] = - claimed.get(ft.meta.idx) + @inline + def claimedRule(implicit ft: FormatToken): Option[Claimant] = + claimedRule(ft.meta.idx) + + @inline + private[rewrite] def claimedRule(ftIdx: Int): Option[Claimant] = + claimed.get(ftIdx) private[FormatTokensRewrite] def applyRule( rule: Rule )(implicit ft: FormatToken, style: ScalafmtConfig): Boolean = rule.enabled && (rule.onToken match { case Some(repl) => - claimed.getOrElseUpdate(ft.meta.idx, rule) - repl.claim.foreach { claimed.getOrElseUpdate(_, rule) } + val claimant = new Claimant(rule, repl) + claimed.getOrElseUpdate(ft.meta.idx, claimant) + repl.claim.foreach { claimed.getOrElseUpdate(_, claimant) } tokens.append(repl) true case _ => false @@ -279,6 +285,11 @@ object FormatTokensRewrite { rules.find(tag.runtimeClass.isInstance).map(_.asInstanceOf[A]) } + private[rewrite] class Claimant( + val rule: Rule, + val replacement: Replacement + ) + private[rewrite] class Replacement( val ft: FormatToken, val how: ReplacementType, diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala index 4c0b6c2dba..8ca2479c48 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala @@ -91,7 +91,7 @@ class RedundantParens(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { case _: Token.RightParen if (left.how eq ReplacementType.Remove) && { val maybeCommaFt = ftoks.prevNonComment(ft) !maybeCommaFt.left.is[Token.Comma] || - session.claimedRule(ftoks.prev(maybeCommaFt)).isDefined + session.claimedRule(maybeCommaFt.meta.idx - 1).isDefined } /* check for trailing comma */ => Some((left, removeToken)) case _ => None diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala index ef31ff7f35..3bc0eee884 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RewriteTrailingCommas.scala @@ -46,8 +46,8 @@ private class RewriteTrailingCommas(ftoks: FormatTokens) case rp: Token.RightParen => rightOwner.isAny[Member.SyntaxValuesClause, Member.Tuple] || ftoks.matchingOpt(rp).exists { lp => - val rule = session.claimedRule(ftoks.justBefore(lp)) - rule.forall(_.isInstanceOf[RedundantParens]) + val claimant = session.claimedRule(ftoks.justBefore(lp)) + claimant.forall(_.rule.isInstanceOf[RedundantParens]) } case _: Token.RightBracket =>