-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Upgrade f-interpolator #13367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Upgrade f-interpolator #13367
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
ffeaa5f
Port Scala 2 f-interpolator
som-snytt fa9f19e
Cleanup dispatch to interpolations
som-snytt 67cddfa
Brace reduction and remove dead code
som-snytt 9826c44
Simplify Conversion
som-snytt 838db3e
Simplify TypedFormatChecker
som-snytt d05c62b
Drop special reporter
som-snytt f33198f
Improve error position and recovery for bad dollar
som-snytt 1b83c92
Pretty print type on f-interpolator, improve caret
som-snytt 07d809f
Use normal string escaping for f
som-snytt 0c87244
Collect error comments
som-snytt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
287 changes: 287 additions & 0 deletions
287
compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
package dotty.tools.dotc | ||
package transform.localopt | ||
|
||
import scala.annotation.tailrec | ||
import scala.collection.mutable.ListBuffer | ||
import scala.util.chaining.* | ||
import scala.util.matching.Regex.Match | ||
|
||
import java.util.{Calendar, Date, Formattable} | ||
|
||
import PartialFunction.cond | ||
|
||
import dotty.tools.dotc.ast.tpd.{Match => _, *} | ||
import dotty.tools.dotc.core.Contexts._ | ||
import dotty.tools.dotc.core.Symbols._ | ||
import dotty.tools.dotc.core.Types._ | ||
import dotty.tools.dotc.core.Phases.typerPhase | ||
import dotty.tools.dotc.util.Spans.Span | ||
|
||
/** Formatter string checker. */ | ||
class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List[Tree])(using Context): | ||
|
||
val argTypes = args.map(_.tpe) | ||
val actuals = ListBuffer.empty[Tree] | ||
|
||
// count of args, for checking indexes | ||
val argc = argTypes.length | ||
|
||
// Pick the first runtime type which the i'th arg can satisfy. | ||
// If conversion is required, implementation must emit it. | ||
def argType(argi: Int, types: Type*): Type = | ||
require(argi < argc, s"$argi out of range picking from $types") | ||
val tpe = argTypes(argi) | ||
types.find(t => argConformsTo(argi, tpe, t)) | ||
.orElse(types.find(t => argConvertsTo(argi, tpe, t))) | ||
.getOrElse { | ||
report.argError(s"Found: ${tpe.show}, Required: ${types.map(_.show).mkString(", ")}", argi) | ||
actuals += args(argi) | ||
types.head | ||
} | ||
|
||
object formattableTypes: | ||
val FormattableType = requiredClassRef("java.util.Formattable") | ||
val BigIntType = requiredClassRef("scala.math.BigInt") | ||
val BigDecimalType = requiredClassRef("scala.math.BigDecimal") | ||
val CalendarType = requiredClassRef("java.util.Calendar") | ||
val DateType = requiredClassRef("java.util.Date") | ||
import formattableTypes.* | ||
def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = (arg <:< target).tap(if _ then actuals += args(argi)) | ||
def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean = | ||
import typer.Implicits.SearchSuccess | ||
atPhase(typerPhase) { | ||
ctx.typer.inferView(args(argi), target) match | ||
case SearchSuccess(view, ref, _, _) => actuals += view ; true | ||
case _ => false | ||
} | ||
|
||
// match a conversion specifier | ||
val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r | ||
|
||
// ordinal is the regex group index in the format pattern | ||
enum SpecGroup: | ||
case Spec, Index, Flags, Width, Precision, CC | ||
import SpecGroup.* | ||
|
||
/** For N part strings and N-1 args to interpolate, normalize parts and check arg types. | ||
* | ||
* Returns normalized part strings and args, where args correcpond to conversions in tail of parts. | ||
*/ | ||
def checked: (List[String], List[Tree]) = | ||
val amended = ListBuffer.empty[String] | ||
val convert = ListBuffer.empty[Conversion] | ||
|
||
@tailrec | ||
def loop(remaining: List[String], n: Int): Unit = | ||
remaining match | ||
case part0 :: more => | ||
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage, index = n, offset = 0)) | ||
val part = try StringContext.processEscapes(part0) catch badPart | ||
val matches = formatPattern.findAllMatchIn(part) | ||
|
||
def insertStringConversion(): Unit = | ||
amended += "%s" + part | ||
convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve | ||
argType(n-1, defn.AnyType) | ||
def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}") | ||
def accept(op: Conversion): Unit = | ||
if !op.isLeading then errorLeading(op) | ||
op.accepts(argType(n-1, op.acceptableVariants*)) | ||
amended += part | ||
convert += op | ||
|
||
// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed | ||
if n == 0 then amended += part | ||
else if !matches.hasNext then insertStringConversion() | ||
else | ||
val cv = Conversion(matches.next(), n) | ||
if cv.isLiteral then insertStringConversion() | ||
else if cv.isIndexed then | ||
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion() | ||
else if !cv.isError then accept(cv) | ||
|
||
// any remaining conversions in this part must be either literals or indexed | ||
while matches.hasNext do | ||
val cv = Conversion(matches.next(), n) | ||
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg") | ||
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv) | ||
|
||
loop(more, n + 1) | ||
case Nil => () | ||
end loop | ||
|
||
loop(parts, n = 0) | ||
if reported then (Nil, Nil) | ||
else | ||
assert(argc == actuals.size, s"Expected ${argc} args but got ${actuals.size} for [${parts.mkString(", ")}]") | ||
(amended.toList, actuals.toList) | ||
end checked | ||
|
||
extension (descriptor: Match) | ||
def at(g: SpecGroup): Int = descriptor.start(g.ordinal) | ||
def end(g: SpecGroup): Int = descriptor.end(g.ordinal) | ||
def offset(g: SpecGroup, i: Int = 0): Int = at(g) + i | ||
def group(g: SpecGroup): Option[String] = Option(descriptor.group(g.ordinal)) | ||
def stringOf(g: SpecGroup): String = group(g).getOrElse("") | ||
def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt) | ||
|
||
extension (inline value: Boolean) | ||
inline def or(inline body: => Unit): Boolean = value || { body ; false } | ||
inline def orElse(inline body: => Unit): Boolean = value || { body ; true } | ||
inline def and(inline body: => Unit): Boolean = value && { body ; true } | ||
inline def but(inline body: => Unit): Boolean = value && { body ; false } | ||
|
||
enum Kind: | ||
case StringXn, HashXn, BooleanXn, CharacterXn, IntegralXn, FloatingPointXn, DateTimeXn, LiteralXn, ErrorXn | ||
import Kind.* | ||
|
||
/** A conversion specifier matched in the argi'th string part, with `argc` arguments to interpolate. | ||
*/ | ||
final class Conversion(val descriptor: Match, val argi: Int, val kind: Kind): | ||
// the descriptor fields | ||
val index: Option[Int] = descriptor.intOf(Index) | ||
val flags: String = descriptor.stringOf(Flags) | ||
val width: Option[Int] = descriptor.intOf(Width) | ||
val precision: Option[Int] = descriptor.group(Precision).map(_.drop(1).toInt) | ||
val op: String = descriptor.stringOf(CC) | ||
|
||
// the conversion char is the head of the op string (but see DateTimeXn) | ||
val cc: Char = | ||
kind match | ||
case ErrorXn => if op.isEmpty then '?' else op(0) | ||
case DateTimeXn => if op.length > 1 then op(1) else '?' | ||
case _ => op(0) | ||
|
||
def isIndexed: Boolean = index.nonEmpty || hasFlag('<') | ||
def isError: Boolean = kind == ErrorXn | ||
def isLiteral: Boolean = kind == LiteralXn | ||
|
||
// descriptor is at index 0 of the part string | ||
def isLeading: Boolean = descriptor.at(Spec) == 0 | ||
|
||
// true if passes. | ||
def verify: Boolean = | ||
// various assertions | ||
def goodies = goodFlags && goodIndex | ||
def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") | ||
def noWidth = width.isEmpty or errorAt(Width)("width not allowed") | ||
def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") | ||
def only_-(msg: String) = | ||
val badFlags = flags.filterNot { case '-' | '<' => true case _ => false } | ||
badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg") | ||
def goodFlags = | ||
val badFlags = flags.filterNot(okFlags.contains) | ||
for f <- badFlags do badFlag(f, s"Illegal flag '$f'") | ||
badFlags.isEmpty | ||
def goodIndex = | ||
if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present") | ||
val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true) | ||
okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") | ||
// begin verify | ||
kind match | ||
case StringXn => goodies | ||
case BooleanXn => goodies | ||
case HashXn => goodies | ||
case CharacterXn => goodies && noPrecision && only_-("c conversion") | ||
case IntegralXn => | ||
def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion") | ||
def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") | ||
goodies && noPrecision && !d_# && !x_comma | ||
case FloatingPointXn => | ||
goodies && (cc match | ||
case 'a' | 'A' => | ||
val badFlags = ",(".filter(hasFlag) | ||
noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) | ||
case _ => true | ||
) | ||
case DateTimeXn => | ||
def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") | ||
def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") | ||
goodies && hasCC && goodCC && noPrecision && only_-("date/time conversions") | ||
case LiteralXn => | ||
op match | ||
case "%" => goodies && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) | ||
case "n" => noFlags && noWidth && noPrecision | ||
case ErrorXn => | ||
errorAt(CC)(s"illegal conversion character '$cc'") | ||
false | ||
end verify | ||
|
||
// is the specifier OK with the given arg | ||
def accepts(arg: Type): Boolean = | ||
kind match | ||
case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean") | ||
case IntegralXn => | ||
arg == BigIntType || !cond(cc) { | ||
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true | ||
} | ||
case _ => true | ||
|
||
// what arg type if any does the conversion accept | ||
def acceptableVariants: List[Type] = | ||
kind match | ||
case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil | ||
case BooleanXn => defn.BooleanType :: defn.NullType :: Nil | ||
case HashXn => defn.AnyType :: Nil | ||
case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil | ||
case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil | ||
case FloatingPointXn => defn.DoubleType :: defn.FloatType :: BigDecimalType :: Nil | ||
case DateTimeXn => defn.LongType :: CalendarType :: DateType :: Nil | ||
case LiteralXn => Nil | ||
case ErrorXn => Nil | ||
|
||
// what flags does the conversion accept? | ||
private def okFlags: String = | ||
kind match | ||
case StringXn => "-#<" | ||
case BooleanXn | HashXn => "-<" | ||
case LiteralXn => "-" | ||
case _ => "-#+ 0,(<" | ||
|
||
def hasFlag(f: Char) = flags.contains(f) | ||
def hasAnyFlag(fs: String) = fs.exists(hasFlag) | ||
|
||
def badFlag(f: Char, msg: String) = | ||
val i = flags.indexOf(f) match { case -1 => 0 case j => j } | ||
errorAt(Flags, i)(msg) | ||
|
||
def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partError(msg, argi, descriptor.offset(g, i), descriptor.end(g)) | ||
def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partWarning(msg, argi, descriptor.offset(g, i), descriptor.end(g)) | ||
|
||
object Conversion: | ||
def apply(m: Match, i: Int): Conversion = | ||
def kindOf(cc: Char) = cc match | ||
case 's' | 'S' => StringXn | ||
case 'h' | 'H' => HashXn | ||
case 'b' | 'B' => BooleanXn | ||
case 'c' | 'C' => CharacterXn | ||
case 'd' | 'o' | | ||
'x' | 'X' => IntegralXn | ||
case 'e' | 'E' | | ||
'f' | | ||
'g' | 'G' | | ||
'a' | 'A' => FloatingPointXn | ||
case 't' | 'T' => DateTimeXn | ||
case '%' | 'n' => LiteralXn | ||
case _ => ErrorXn | ||
end kindOf | ||
m.group(CC) match | ||
case Some(cc) => new Conversion(m, i, kindOf(cc(0))).tap(_.verify) | ||
case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp")) | ||
end apply | ||
val literalHelp = "use %% for literal %, %n for newline" | ||
end Conversion | ||
|
||
var reported = false | ||
|
||
private def partPosAt(index: Int, offset: Int, end: Int) = | ||
val pos = partsElems(index).sourcePos | ||
val bgn = pos.span.start + offset | ||
val fin = if end < 0 then pos.span.end else pos.span.start + end | ||
pos.withSpan(Span(bgn, fin, bgn)) | ||
|
||
extension (r: report.type) | ||
def argError(message: String, index: Int): Unit = r.error(message, args(index).srcPos).tap(_ => reported = true) | ||
def partError(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.error(message, partPosAt(index, offset, end)).tap(_ => reported = true) | ||
def partWarning(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.warning(message, partPosAt(index, offset, end)).tap(_ => reported = true) | ||
end TypedFormatChecker |
39 changes: 39 additions & 0 deletions
39
compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package dotty.tools.dotc | ||
package transform.localopt | ||
|
||
import dotty.tools.dotc.ast.tpd.* | ||
import dotty.tools.dotc.core.Constants.Constant | ||
import dotty.tools.dotc.core.Contexts.* | ||
|
||
object FormatInterpolatorTransform: | ||
|
||
/** For f"${arg}%xpart", check format conversions and return (format, args) | ||
* suitable for String.format(format, args). | ||
*/ | ||
def checked(fun: Tree, args0: Tree)(using Context): (Tree, Tree) = | ||
val (partsExpr, parts) = fun match | ||
case TypeApply(Select(Apply(_, (parts: SeqLiteral) :: Nil), _), _) => | ||
(parts.elems, parts.elems.map { case Literal(Constant(s: String)) => s }) | ||
case _ => | ||
report.error("Expected statically known StringContext", fun.srcPos) | ||
(Nil, Nil) | ||
val (args, elemtpt) = args0 match | ||
case seqlit: SeqLiteral => (seqlit.elems, seqlit.elemtpt) | ||
case _ => | ||
report.error("Expected statically known argument list", args0.srcPos) | ||
(Nil, EmptyTree) | ||
|
||
def literally(s: String) = Literal(Constant(s)) | ||
if parts.lengthIs != args.length + 1 then | ||
val badParts = | ||
if parts.isEmpty then "there are no parts" | ||
else s"too ${if parts.lengthIs > args.length + 1 then "few" else "many"} arguments for interpolated string" | ||
report.error(badParts, fun.srcPos) | ||
(literally(""), args0) | ||
else | ||
val checker = TypedFormatChecker(partsExpr, parts, args) | ||
val (format, formatArgs) = checker.checked | ||
if format.isEmpty then (literally(parts.mkString), args0) | ||
else (literally(format.mkString), SeqLiteral(formatArgs.toList, elemtpt)) | ||
end checked | ||
end FormatInterpolatorTransform |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.