Skip to content

Commit

Permalink
BestFirstSearch: skip processOptimal if too long
Browse files Browse the repository at this point in the history
An optimal token would succeed if all splits cost 0, and there are no
disallowed overflow tokens. In practice, this means that the span before
the optimal token can't be way much longer than the line length.
  • Loading branch information
kitbellew committed Oct 29, 2024
1 parent 188cc88 commit 1faeaed
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,30 +169,45 @@ private class BestFirstSearch private (range: Set[Range])(implicit
state.next(split)
}

private def killOnFail(opt: OptimalToken, nextNextState: State = null)(
implicit nextState: State,
): State = {
val kill = opt.killOnFail || nextState.hasSlbUntil {
if (
(null ne nextNextState) &&
nextNextState.appliedPenalty > nextState.prev.appliedPenalty
) tokens(nextNextState.depth)
else tokens(opt.token)
}
private def killOnFail(
isKillOnFail: Boolean,
)(end: => FormatToken)(implicit nextState: State): State = {
val kill = isKillOnFail || nextState.hasSlbUntil(end)
if (kill) null else nextState
}

private def processOptimalToken(
private def killOnFail(
opt: OptimalToken,
)(implicit nextState: State, queue: StateQueue): Either[State, State] = {
end: => FormatToken = null,
nextNextState: State = null,
)(implicit nextState: State): State = killOnFail(opt.killOnFail) {
if (
(null ne nextNextState) &&
nextNextState.appliedPenalty > nextState.prev.appliedPenalty
) tokens(nextNextState.depth)
else {
val optEnd = end
if (optEnd ne null) optEnd else tokens(opt.token)
}
}

private def processOptimalToken(opt: OptimalToken)(implicit
nextState: State,
queue: StateQueue,
style: ScalafmtConfig,
): Either[State, State] = {
val optEnd = tokens(opt.token)
val nextNextState =
if (opt.token.end <= tokens(nextState.depth).left.end) nextState
else if (
tokens.width(tokens(nextState.depth), optEnd) > 3 * style.maxColumn
) return Left(killOnFail(opt.killOnFail)(optEnd))
else {
val res =
shortestPath(nextState, opt.token, queue.nested + 1, isOpt = true)
res match {
case Right(x) => x
case Left(x) => return Left(killOnFail(opt, x))
case Left(x) => return Left(killOnFail(opt, optEnd, x))
}
}
def checkPenalty(state: State, orElse: => Either[State, State]) =
Expand All @@ -201,7 +216,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
else orElse
traverseSameLine(nextNextState) match {
case x @ Left(s) =>
if (s eq null) Left(killOnFail(opt, nextNextState))
if (s eq null) Left(killOnFail(opt, optEnd, nextNextState))
else checkPenalty(s, x)
case x @ Right(s) => checkPenalty(s, if (opt.recurseOnly) Left(s) else x)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract class CommunityIntellijScalaSuite(name: String)
class CommunityIntellijScala_2024_2_Suite
extends CommunityIntellijScalaSuite("intellij-scala-2024.2") {

override protected def totalStatesVisited: Option[Int] = Some(47828843)
override protected def totalStatesVisited: Option[Int] = Some(47152559)

override protected def builds = Seq(getBuild(
"2024.2.28",
Expand Down Expand Up @@ -51,7 +51,7 @@ class CommunityIntellijScala_2024_2_Suite
class CommunityIntellijScala_2024_3_Suite
extends CommunityIntellijScalaSuite("intellij-scala-2024.3") {

override protected def totalStatesVisited: Option[Int] = Some(48007538)
override protected def totalStatesVisited: Option[Int] = Some(47328521)

override protected def builds = Seq(getBuild(
"2024.3.4",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract class CommunityScala2Suite(name: String)

class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") {

override protected def totalStatesVisited: Option[Int] = Some(35257028)
override protected def totalStatesVisited: Option[Int] = Some(34606038)

override protected def builds =
Seq(getBuild("v2.12.20", dialects.Scala212, 1277))
Expand All @@ -18,7 +18,7 @@ class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") {

class CommunityScala2_13Suite extends CommunityScala2Suite("scala-2.13") {

override protected def totalStatesVisited: Option[Int] = Some(43926266)
override protected def totalStatesVisited: Option[Int] = Some(43123923)

override protected def builds =
Seq(getBuild("v2.13.14", dialects.Scala213, 1287))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ abstract class CommunityScala3Suite(name: String)

class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") {

override protected def totalStatesVisited: Option[Int] = Some(32909474)
override protected def totalStatesVisited: Option[Int] = Some(32241176)

override protected def builds = Seq(getBuild("3.2.2", dialects.Scala32, 791))

}

class CommunityScala3_3Suite extends CommunityScala3Suite("scala-3.3") {

override protected def totalStatesVisited: Option[Int] = Some(35508610)
override protected def totalStatesVisited: Option[Int] = Some(34804029)

override protected def builds = Seq(getBuild("3.3.3", dialects.Scala33, 861))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract class CommunitySparkSuite(name: String)

class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") {

override protected def totalStatesVisited: Option[Int] = Some(71532918)
override protected def totalStatesVisited: Option[Int] = Some(70329765)

override protected def builds =
Seq(getBuild("v3.4.1", dialects.Scala213, 2585))
Expand All @@ -18,7 +18,7 @@ class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") {

class CommunitySpark3_5Suite extends CommunitySparkSuite("spark-3.5") {

override protected def totalStatesVisited: Option[Int] = Some(75670780)
override protected def totalStatesVisited: Option[Int] = Some(74403516)

override protected def builds =
Seq(getBuild("v3.5.3", dialects.Scala213, 2756))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ private def withNewLocalDefs = {
}))
}))
}
>>> { stateVisits = 3750 }
>>> { stateVisits = 3618 }
val createIsArrayOfStat = {
envFieldDef(
"isArrayOf",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ class ResolutionCopier(x: Int) {
toNode.rightBracket));
}
}
>>> { stateVisits = 57624 }
>>> { stateVisits = 57581 }
class ResolutionCopier(x: Int) {

def visitClassDeclaration(node: ClassDeclaration): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ private[parser] trait CacheControlHeader { this: Parser with CommonRules with Co
clearSB() ~ zeroOrMore(!'"' ~ !',' ~ qdtext ~ appendSB() | `quoted-pair`) ~ push(sb.toString)
}
}
>>> { stateVisits = 36215, stateVisits2 = 36215 }
>>> { stateVisits = 36181, stateVisits2 = 36181 }
/** Copyright (C) 2009-2016 Lightbend Inc. <http://www.lightbend.com>
*/

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9353,7 +9353,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil
}
}
>>> { stateVisits = 2846, stateVisits2 = 1042 }
>>> { stateVisits = 2768, stateVisits2 = 1038 }
class UDFRegistration private[sql] (
functionRegistry: FunctionRegistry
) extends Logging {
Expand Down Expand Up @@ -9584,7 +9584,7 @@ class UDFRegistration {
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]).toOption :: Try(ExpressionEncoder[A2]).toOption :: Try(ExpressionEncoder[A3]).toOption :: Try(ExpressionEncoder[A4]).toOption :: Try(ExpressionEncoder[A5]).toOption :: Try(ExpressionEncoder[A6]).toOption :: Try(ExpressionEncoder[A7]).toOption :: Try(ExpressionEncoder[A8]).toOption :: Try(ExpressionEncoder[A9]).toOption :: Try(ExpressionEncoder[A10]).toOption :: Try(ExpressionEncoder[A11]).toOption :: Try(ExpressionEncoder[A12]).toOption :: Try(ExpressionEncoder[A13]).toOption :: Try(ExpressionEncoder[A14]).toOption :: Try(ExpressionEncoder[A15]).toOption :: Try(ExpressionEncoder[A16]).toOption :: Try(ExpressionEncoder[A17]).toOption :: Try(ExpressionEncoder[A18]).toOption :: Try(ExpressionEncoder[A19]).toOption :: Try(ExpressionEncoder[A20]).toOption :: Try(ExpressionEncoder[A21]).toOption :: Try(ExpressionEncoder[A22]).toOption :: Nil
}
}
>>> { stateVisits = 2347, stateVisits2 = 752 }
>>> { stateVisits = 2292, stateVisits2 = 748 }
class UDFRegistration {
def foo = {
val inputEncoders: Seq[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8363,7 +8363,7 @@ object a {
)
)
}
>>> { stateVisits = 4946, stateVisits2 = 4946 }
>>> { stateVisits = 4788, stateVisits2 = 4788 }
object a {
div(cls := "cover")(
div(cls := "doc")(bodyContents),
Expand Down Expand Up @@ -8450,7 +8450,7 @@ object a {
)
)
}
>>> { stateVisits = 5051, stateVisits2 = 5051 }
>>> { stateVisits = 4849, stateVisits2 = 4849 }
object a {
div(cls := "cover")(
div(cls := "doc")(bodyContents),
Expand Down Expand Up @@ -9460,7 +9460,7 @@ object a {
}
}
}
>>> { stateVisits = 2821, stateVisits2 = 2821 }
>>> { stateVisits = 2428, stateVisits2 = 2428 }
object a {
private object MemoMap {
def make(implicit trace: Trace): UIO[MemoMap] = Ref.Synchronized
Expand Down Expand Up @@ -9578,7 +9578,7 @@ object a {
.map(_.filterNot(_.getCanonicalPath.contains("SSLOptions")))
}
}
>>> { stateVisits = 187773, stateVisits2 = 187773 }
>>> { stateVisits = 187751, stateVisits2 = 187751 }
object a {
private def ignoreUndocumentedPackages(
packages: Seq[Seq[File]]
Expand Down Expand Up @@ -9676,7 +9676,7 @@ object a {
.map(filterNot(getCanonicalPath.contains("SSLOptions")))
}
}
>>> { stateVisits = 12063, stateVisits2 = 12063 }
>>> { stateVisits = 12046, stateVisits2 = 12046 }
object a {
private def ignoreUndocumentedPackages(
packages: Seq[Seq[File]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions {
val explored = Debug.explored.get()
logger.debug(s"Total explored: $explored")
if (!onlyUnit && !onlyManual)
assertEquals(explored, 1484918, "total explored")
assertEquals(explored, 1473669, "total explored")
val results = debugResults.result()
// TODO(olafur) don't block printing out test results.
// I don't want to deal with scalaz's Tasks :'(
Expand Down

0 comments on commit 1faeaed

Please sign in to comment.