Skip to content

Commit

Permalink
Scalafmt: don't parse shebang line as scala
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Oct 3, 2023
1 parent 5641152 commit d8ffb3f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 21 deletions.
18 changes: 16 additions & 2 deletions scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,33 @@ object Scalafmt {
)
}

private[scalafmt] def splitCodePrefix(input: String): (String, String) =
if (!input.startsWith("#!")) ("", input)
else {
val beforeNL = input.indexOf(UnixLineEnding, 2)
if (beforeNL < 0) (input + UnixLineEnding, "")
else {
val afterNL = beforeNL + UnixLineEnding.length
val hasBlank = input.startsWith(UnixLineEnding, afterNL)
val idx = if (hasBlank) afterNL + UnixLineEnding.length else afterNL
(input.substring(0, idx), input.substring(idx))
}
}

private def formatCodeWithStyle(
code: String,
style: ScalafmtConfig,
range: Set[Range],
filename: String
): Formatted.Result = {
val isWin = code.contains(WinLineEnding)
val unixCode =
val (prefix, unixCode) = splitCodePrefix(
if (isWin) code.replaceAll(WinLineEnding, UnixLineEnding) else code
)
doFormat(unixCode, style, filename, range) match {
case Failure(e) => Formatted.Result(Formatted.Failure(e), style)
case Success(x) =>
val s = if (x.isEmpty) UnixLineEnding else x
val s = if (prefix.isEmpty && x.isEmpty) UnixLineEnding else prefix + x
val asWin = style.lineEndings == LineEndings.windows ||
(isWin && style.lineEndings == LineEndings.preserve)
val res = if (asWin) s.replaceAll(UnixLineEnding, WinLineEnding) else s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,9 @@ class Router(formatOps: FormatOps) {
case FormatToken(_: T.EOF, _, _) => Seq(Split(Newline, 0))
case ft @ FormatToken(_, _: T.BOF, _) =>
Seq(Split(NoSplit.orNL(next(ft).right.is[T.EOF]), 0))
case FormatToken(_: T.BOF, right, _) =>
val policy = right match {
case T.Ident(name) if name.startsWith("#!") =>
val nl = findFirst(next(formatToken), Int.MaxValue) { x =>
x.hasBreak || x.right.is[T.EOF]
}
nl.fold(Policy.noPolicy) { ft =>
Policy.on(ft.left) { case Decision(t, _) =>
Seq(Split(Space(t.between.nonEmpty), 0))
}
}
case _ => Policy.NoPolicy
}
case FormatToken(_: T.BOF, _, _) =>
Seq(
Split(NoSplit, 0).withPolicy(policy)
Split(NoSplit, 0)
)
case FormatToken(_, _: T.EOF, _) =>
Seq(
Expand Down
8 changes: 4 additions & 4 deletions scalafmt-tests/src/test/resources/test/Dialect.source
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ addSbtPlugin( "io.get-coursier" % "sbt-coursier" % "1.0.0-M14")
addSbtPlugin( "com.eed3si9n" % "sbt-assembly" % "0.14.3")
addSbtPlugin( "org.brianmckenna" % "sbt-wartremover" % "0.14")
>>> foo.sc
#!/usr/bin/env amm
#!/usr/bin/env amm

addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-M14")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")
Expand All @@ -51,7 +51,7 @@ rewrite.rules = [AvoidInfix]
===
#!/usr/bin/env amm
>>> foo.sc
#!/usr/bin/env amm
#!/usr/bin/env amm
<<< #2104
runner.dialect = scala3
===
Expand Down Expand Up @@ -232,6 +232,6 @@ foo(
def main(args: Array[String]): Unit =
println("hello")
>>>
test does not parse
#!/usr/bin/env -S scala @classpathAtfile
^
def main(args: Array[String]): Unit =
println("hello")
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ trait FormatAssertions {
obtained: String
)(implicit ev: Parse[T], dialect: Dialect): Unit = {
import scala.meta._
def toInput(code: String) = Scalafmt.toInput(code, filename)
def toInput(code: String) =
Scalafmt.toInput(Scalafmt.splitCodePrefix(code)._2, filename)
toInput(original).parse[T] match {
case Parsed.Error(pos, message, _) =>
val msgWithPos = pos.formatMessage("error", message)
Expand Down

0 comments on commit d8ffb3f

Please sign in to comment.