Skip to content

Fix #2578: (part 2) Make for-generators filter only if prefixed with case. #6448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 15, 2019
113 changes: 69 additions & 44 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,19 @@ object desugar {
*/
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key

/** An attachment for match expressions generated from a PatDef */
val PatDefMatch: Property.Key[Unit] = new Property.Key
/** An attachment for match expressions generated from a PatDef or GenFrom.
* Value of key == one of IrrefutablePatDef, IrrefutableGenFrom
*/
val CheckIrrefutable: Property.Key[MatchCheck] = new Property.StickyKey

/** What static check should be applied to a Match (none, irrefutable, exhaustive) */
class MatchCheck(val n: Int) extends AnyVal
object MatchCheck {
val None = new MatchCheck(0)
val Exhaustive = new MatchCheck(1)
val IrrefutablePatDef = new MatchCheck(2)
val IrrefutableGenFrom = new MatchCheck(3)
}

/** Info of a variable in a pattern: The named tree and its type */
private type VarInfo = (NameTree, Tree)
Expand Down Expand Up @@ -926,6 +937,22 @@ object desugar {
}
}

/** The selector of a match, which depends of the given `checkMode`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact that @unchecked is also added to irrefutable patterns really confused me until I understood that annotating the selector only affects exhaustivity checks and not irrefutability checks. Perhaps we could note this in the doc?

Suggested change
/** The selector of a match, which depends of the given `checkMode`.
/** The selector of a match, which depends of the given `checkMode`.
*
* The @unchecked annotation is added whenever `checkMode` is not `Exhaustive` to silence
* unnecessary inexhaustive match warnings.

* @param sel the original selector
* @return if `checkMode` is
* - None : sel @unchecked
* - Exhaustive : sel
* - IrrefutablePatDef,
* IrrefutableGenFrom: sel @unchecked with attachment `CheckIrrefutable -> checkMode`
*/
def makeSelector(sel: Tree, checkMode: MatchCheck)(implicit ctx: Context): Tree =
if (checkMode == MatchCheck.Exhaustive) sel
else {
val sel1 = Annotated(sel, New(ref(defn.UncheckedAnnotType)))
if (checkMode != MatchCheck.None) sel1.pushAttachment(CheckIrrefutable, checkMode)
sel1
}

/** If `pat` is a variable pattern,
*
* val/var/lazy val p = e
Expand Down Expand Up @@ -960,11 +987,6 @@ object desugar {
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
val tupleOptimizable = forallResults(rhs, isMatchingTuple)

def rhsUnchecked = {
val rhs1 = makeAnnotated("scala.unchecked", rhs)
rhs1.pushAttachment(PatDefMatch, ())
rhs1
}
val vars =
if (tupleOptimizable) // include `_`
pat match {
Expand All @@ -977,7 +999,7 @@ object desugar {
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids))
val matchExpr =
if (tupleOptimizable) rhs
else Match(rhsUnchecked, caseDef :: Nil)
else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)
vars match {
case Nil =>
matchExpr
Expand Down Expand Up @@ -1120,20 +1142,16 @@ object desugar {
*
* { cases }
* ==>
* x$1 => (x$1 @unchecked) match { cases }
* x$1 => (x$1 @unchecked?) match { cases }
*
* If `nparams` != 1, expand instead to
*
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked) match { cases }
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked?) match { cases }
*/
def makeCaseLambda(cases: List[CaseDef], nparams: Int = 1, unchecked: Boolean = true)(implicit ctx: Context): Function = {
def makeCaseLambda(cases: List[CaseDef], checkMode: MatchCheck, nparams: Int = 1)(implicit ctx: Context): Function = {
val params = (1 to nparams).toList.map(makeSyntheticParameter(_))
val selector = makeTuple(params.map(p => Ident(p.name)))

if (unchecked)
Function(params, Match(Annotated(selector, New(ref(defn.UncheckedAnnotType))), cases))
else
Function(params, Match(selector, cases))
Function(params, Match(makeSelector(selector, checkMode), cases))
}

/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
Expand Down Expand Up @@ -1262,15 +1280,19 @@ object desugar {
*/
def makeFor(mapName: TermName, flatMapName: TermName, enums: List[Tree], body: Tree): Tree = trace(i"make for ${ForYield(enums, body)}", show = true) {

/** Make a function value pat => body.
* If pat is a var pattern id: T then this gives (id: T) => body
* Otherwise this gives { case pat => body }
/** Let `pat` be `gen`'s pattern. Make a function value `pat => body`.
* If `pat` is a var pattern `id: T` then this gives `(id: T) => body`.
* Otherwise this gives `{ case pat => body }`, where `pat` is checked to be
* irrefutable if `gen`'s checkMode is GenCheckMode.Check.
*/
def makeLambda(pat: Tree, body: Tree): Tree = pat match {
case IdPattern(named, tpt) =>
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
def makeLambda(gen: GenFrom, body: Tree): Tree = gen.pat match {
case IdPattern(named, tpt) if gen.checkMode != GenCheckMode.FilterAlways =>
Function(derivedValDef(gen.pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
case _ =>
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil)
val matchCheckMode =
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
else MatchCheck.None
makeCaseLambda(CaseDef(gen.pat, EmptyTree, body) :: Nil, matchCheckMode)
}

/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
Expand Down Expand Up @@ -1316,7 +1338,7 @@ object desugar {
val cases = List(
CaseDef(pat, EmptyTree, Literal(Constant(true))),
CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))))
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases))
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases, MatchCheck.None))
}

/** Is pattern `pat` irrefutable when matched against `rhs`?
Expand All @@ -1342,41 +1364,47 @@ object desugar {
}
}

def isIrrefutableGenFrom(gen: GenFrom): Boolean =
gen.isInstanceOf[IrrefutableGenFrom] ||
IdPattern.unapply(gen.pat).isDefined ||
isIrrefutable(gen.pat, gen.expr)
def needsNoFilter(gen: GenFrom): Boolean =
if (gen.checkMode == GenCheckMode.FilterAlways) // pattern was prefixed by `case`
false
else (
gen.checkMode != GenCheckMode.FilterNow ||
IdPattern.unapply(gen.pat).isDefined ||
isIrrefutable(gen.pat, gen.expr)
)

/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
* matched against `rhs`.
*/
def rhsSelect(gen: GenFrom, name: TermName) = {
val rhs = if (isIrrefutableGenFrom(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
val rhs = if (needsNoFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
Select(rhs, name)
}

def checkMode(gen: GenFrom) =
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
else MatchCheck.None // refutable paterns were already eliminated in filter step

enums match {
case (gen: GenFrom) :: Nil =>
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _) :: _)) =>
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
case (GenFrom(pat, rhs)) :: (rest @ GenAlias(_, _) :: _) =>
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
val (defpat0, id0) = makeIdPat(pat)
val (defpat0, id0) = makeIdPat(gen.pat)
val (defpats, ids) = (pats map makeIdPat).unzip
val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _))
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, rhs) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = pat :: pats
val vfrom1 = new IrrefutableGenFrom(makeTuple(allpats), rhs1)
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
val allpats = gen.pat :: pats
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test))
val genFrom =
if (isIrrefutableGenFrom(gen)) new IrrefutableGenFrom(gen.pat, filtered)
else GenFrom(gen.pat, filtered)
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case _ =>
EmptyTree //may happen for erroneous input
Expand Down Expand Up @@ -1571,7 +1599,4 @@ object desugar {
collect(tree)
buf.toList
}

private class IrrefutableGenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile)
extends GenFrom(pat, expr)
}
23 changes: 16 additions & 7 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class DoWhile(body: Tree, cond: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class ForYield(enums: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class ForDo(enums: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
case class GenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
case class GenFrom(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit @constructorOnly src: SourceFile) extends Tree
case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree
Expand All @@ -116,6 +116,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
* `Positioned#checkPos` */
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)

/** An enum to control checking or filtering of patterns in GenFrom trees */
class GenCheckMode(val x: Int) extends AnyVal
object GenCheckMode {
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
val Check = new GenCheckMode(1) // check that pattern is irrefutable
val FilterNow = new GenCheckMode(2) // filter out non-matching elements since we are not in -strict
val FilterAlways = new GenCheckMode(3) // filter out non-matching elements since pattern is prefixed by `case`
}

// ----- Modifiers -----------------------------------------------------
/** Mod is intended to record syntactic information about modifiers, it's
* NOT a replacement of FlagSet.
Expand Down Expand Up @@ -525,9 +534,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case tree: ForDo if (enums eq tree.enums) && (body eq tree.body) => tree
case _ => finalize(tree, untpd.ForDo(enums, body)(tree.source))
}
def GenFrom(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.GenFrom(pat, expr)(tree.source))
def GenFrom(tree: Tree)(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit ctx: Context): Tree = tree match {
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (checkMode == tree.checkMode) => tree
case _ => finalize(tree, untpd.GenFrom(pat, expr, checkMode)(tree.source))
}
def GenAlias(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
case tree: GenAlias if (pat eq tree.pat) && (expr eq tree.expr) => tree
Expand Down Expand Up @@ -589,8 +598,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
cpy.ForYield(tree)(transform(enums), transform(expr))
case ForDo(enums, body) =>
cpy.ForDo(tree)(transform(enums), transform(body))
case GenFrom(pat, expr) =>
cpy.GenFrom(tree)(transform(pat), transform(expr))
case GenFrom(pat, expr, checkMode) =>
cpy.GenFrom(tree)(transform(pat), transform(expr), checkMode)
case GenAlias(pat, expr) =>
cpy.GenAlias(tree)(transform(pat), transform(expr))
case ContextBounds(bounds, cxBounds) =>
Expand Down Expand Up @@ -644,7 +653,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
this(this(x, enums), expr)
case ForDo(enums, body) =>
this(this(x, enums), body)
case GenFrom(pat, expr) =>
case GenFrom(pat, expr, _) =>
this(this(x, pat), expr)
case GenAlias(pat, expr) =>
this(this(x, pat), expr)
Expand Down
48 changes: 29 additions & 19 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1725,18 +1725,28 @@ object Parsers {
*/
def enumerator(): Tree =
if (in.token == IF) guard()
else if (in.token == CASE) generator()
else {
val pat = pattern1()
if (in.token == EQUALS) atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, expr()) }
else generatorRest(pat)
else generatorRest(pat, casePat = false)
}

/** Generator ::= Pattern `<-' Expr
/** Generator ::= [‘case’] Pattern `<-' Expr
*/
def generator(): Tree = generatorRest(pattern1())
def generator(): Tree = {
val casePat = if (in.token == CASE) { in.skipCASE(); true } else false
generatorRest(pattern1(), casePat)
}

def generatorRest(pat: Tree): GenFrom =
atSpan(startOffset(pat), accept(LARROW)) { GenFrom(pat, expr()) }
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
atSpan(startOffset(pat), accept(LARROW)) {
val checkMode =
if (casePat) GenCheckMode.FilterAlways
else if (ctx.settings.strict.value) GenCheckMode.Check
else GenCheckMode.FilterNow // filter for now, to keep backwards compat
GenFrom(pat, expr(), checkMode)
}

/** ForExpr ::= `for' (`(' Enumerators `)' | `{' Enumerators `}')
* {nl} [`yield'] Expr
Expand All @@ -1749,16 +1759,20 @@ object Parsers {
else if (in.token == LPAREN) {
val lparenOffset = in.skipToken()
openParens.change(LPAREN, 1)
val pats = patternsOpt()
val pat =
if (in.token == RPAREN || pats.length > 1) {
wrappedEnums = false
accept(RPAREN)
openParens.change(LPAREN, -1)
atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer.
val res =
if (in.token == CASE) enumerators()
else {
val pats = patternsOpt()
val pat =
if (in.token == RPAREN || pats.length > 1) {
wrappedEnums = false
accept(RPAREN)
openParens.change(LPAREN, -1)
atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer.
}
else pats.head
generatorRest(pat, casePat = false) :: enumeratorsRest()
}
else pats.head
val res = generatorRest(pat) :: enumeratorsRest()
if (wrappedEnums) {
accept(RPAREN)
openParens.change(LPAREN, -1)
Expand Down Expand Up @@ -2640,11 +2654,7 @@ object Parsers {
*/
def enumCase(start: Offset, mods: Modifiers): DefTree = {
val mods1 = addMod(mods, atSpan(in.offset)(Mod.Enum())) | Case
accept(CASE)

in.adjustSepRegions(ARROW)
// Scanner thinks it is in a pattern match after seeing the `case`.
// We need to get it out of that mode by telling it we are past the `=>`
in.skipCASE()

atSpan(start, nameStart) {
val id = termIdent()
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,16 @@ object Scanners {
case _ =>
}

/** Advance beyond a case token without marking the CASE in sepRegions.
* This method should be called to skip beyond CASE tokens that are
* not part of matches, i.e. no ARROW is expected after them.
*/
def skipCASE() = {
assert(token == CASE)
nextToken()
sepRegions = sepRegions.tail
}

/** Produce next token, filling TokenData fields of Scanner.
*/
def nextToken(): Unit = {
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
forText(enums, expr, keywordStr(" yield "))
case ForDo(enums, expr) =>
forText(enums, expr, keywordStr(" do "))
case GenFrom(pat, expr) =>
case GenFrom(pat, expr, checkMode) =>
(Str("case ") provided checkMode == untpd.GenCheckMode.FilterAlways) ~
toText(pat) ~ " <- " ~ toText(expr)
case GenAlias(pat, expr) =>
toText(pat) ~ " = " ~ toText(expr)
Expand Down
Loading