diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index b86efa2bf20a..8f84b8fbf838 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -33,8 +33,19 @@ object desugar { */ val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key - /** An attachment for match expressions generated from a PatDef */ - val PatDefMatch: Property.Key[Unit] = new Property.Key + /** An attachment for match expressions generated from a PatDef or GenFrom. + * Value of key == one of IrrefutablePatDef, IrrefutableGenFrom + */ + val CheckIrrefutable: Property.Key[MatchCheck] = new Property.StickyKey + + /** What static check should be applied to a Match (none, irrefutable, exhaustive) */ + class MatchCheck(val n: Int) extends AnyVal + object MatchCheck { + val None = new MatchCheck(0) + val Exhaustive = new MatchCheck(1) + val IrrefutablePatDef = new MatchCheck(2) + val IrrefutableGenFrom = new MatchCheck(3) + } /** Info of a variable in a pattern: The named tree and its type */ private type VarInfo = (NameTree, Tree) @@ -926,6 +937,22 @@ object desugar { } } + /** The selector of a match, which depends of the given `checkMode`. + * @param sel the original selector + * @return if `checkMode` is + * - None : sel @unchecked + * - Exhaustive : sel + * - IrrefutablePatDef, + * IrrefutableGenFrom: sel @unchecked with attachment `CheckIrrefutable -> checkMode` + */ + def makeSelector(sel: Tree, checkMode: MatchCheck)(implicit ctx: Context): Tree = + if (checkMode == MatchCheck.Exhaustive) sel + else { + val sel1 = Annotated(sel, New(ref(defn.UncheckedAnnotType))) + if (checkMode != MatchCheck.None) sel1.pushAttachment(CheckIrrefutable, checkMode) + sel1 + } + /** If `pat` is a variable pattern, * * val/var/lazy val p = e @@ -960,11 +987,6 @@ object desugar { // - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)` val tupleOptimizable = forallResults(rhs, isMatchingTuple) - def rhsUnchecked = { - val rhs1 = makeAnnotated("scala.unchecked", rhs) - rhs1.pushAttachment(PatDefMatch, ()) - rhs1 - } val vars = if (tupleOptimizable) // include `_` pat match { @@ -977,7 +999,7 @@ object desugar { val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids)) val matchExpr = if (tupleOptimizable) rhs - else Match(rhsUnchecked, caseDef :: Nil) + else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil) vars match { case Nil => matchExpr @@ -1120,20 +1142,16 @@ object desugar { * * { cases } * ==> - * x$1 => (x$1 @unchecked) match { cases } + * x$1 => (x$1 @unchecked?) match { cases } * * If `nparams` != 1, expand instead to * - * (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked) match { cases } + * (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked?) match { cases } */ - def makeCaseLambda(cases: List[CaseDef], nparams: Int = 1, unchecked: Boolean = true)(implicit ctx: Context): Function = { + def makeCaseLambda(cases: List[CaseDef], checkMode: MatchCheck, nparams: Int = 1)(implicit ctx: Context): Function = { val params = (1 to nparams).toList.map(makeSyntheticParameter(_)) val selector = makeTuple(params.map(p => Ident(p.name))) - - if (unchecked) - Function(params, Match(Annotated(selector, New(ref(defn.UncheckedAnnotType))), cases)) - else - Function(params, Match(selector, cases)) + Function(params, Match(makeSelector(selector, checkMode), cases)) } /** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows: @@ -1262,15 +1280,19 @@ object desugar { */ def makeFor(mapName: TermName, flatMapName: TermName, enums: List[Tree], body: Tree): Tree = trace(i"make for ${ForYield(enums, body)}", show = true) { - /** Make a function value pat => body. - * If pat is a var pattern id: T then this gives (id: T) => body - * Otherwise this gives { case pat => body } + /** Let `pat` be `gen`'s pattern. Make a function value `pat => body`. + * If `pat` is a var pattern `id: T` then this gives `(id: T) => body`. + * Otherwise this gives `{ case pat => body }`, where `pat` is checked to be + * irrefutable if `gen`'s checkMode is GenCheckMode.Check. */ - def makeLambda(pat: Tree, body: Tree): Tree = pat match { - case IdPattern(named, tpt) => - Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body) + def makeLambda(gen: GenFrom, body: Tree): Tree = gen.pat match { + case IdPattern(named, tpt) if gen.checkMode != GenCheckMode.FilterAlways => + Function(derivedValDef(gen.pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body) case _ => - makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil) + val matchCheckMode = + if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom + else MatchCheck.None + makeCaseLambda(CaseDef(gen.pat, EmptyTree, body) :: Nil, matchCheckMode) } /** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap @@ -1316,7 +1338,7 @@ object desugar { val cases = List( CaseDef(pat, EmptyTree, Literal(Constant(true))), CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false)))) - Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases)) + Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases, MatchCheck.None)) } /** Is pattern `pat` irrefutable when matched against `rhs`? @@ -1342,41 +1364,47 @@ object desugar { } } - def isIrrefutableGenFrom(gen: GenFrom): Boolean = - gen.isInstanceOf[IrrefutableGenFrom] || - IdPattern.unapply(gen.pat).isDefined || - isIrrefutable(gen.pat, gen.expr) + def needsNoFilter(gen: GenFrom): Boolean = + if (gen.checkMode == GenCheckMode.FilterAlways) // pattern was prefixed by `case` + false + else ( + gen.checkMode != GenCheckMode.FilterNow || + IdPattern.unapply(gen.pat).isDefined || + isIrrefutable(gen.pat, gen.expr) + ) /** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when * matched against `rhs`. */ def rhsSelect(gen: GenFrom, name: TermName) = { - val rhs = if (isIrrefutableGenFrom(gen)) gen.expr else makePatFilter(gen.expr, gen.pat) + val rhs = if (needsNoFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat) Select(rhs, name) } + def checkMode(gen: GenFrom) = + if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom + else MatchCheck.None // refutable paterns were already eliminated in filter step + enums match { case (gen: GenFrom) :: Nil => - Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body)) - case (gen: GenFrom) :: (rest @ (GenFrom(_, _) :: _)) => + Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => val cont = makeFor(mapName, flatMapName, rest, body) - Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont)) - case (GenFrom(pat, rhs)) :: (rest @ GenAlias(_, _) :: _) => + Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) + case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) val pats = valeqs map { case GenAlias(pat, _) => pat } val rhss = valeqs map { case GenAlias(_, rhs) => rhs } - val (defpat0, id0) = makeIdPat(pat) + val (defpat0, id0) = makeIdPat(gen.pat) val (defpats, ids) = (pats map makeIdPat).unzip val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _)) - val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, rhs) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) - val allpats = pat :: pats - val vfrom1 = new IrrefutableGenFrom(makeTuple(allpats), rhs1) + val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) + val allpats = gen.pat :: pats + val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) makeFor(mapName, flatMapName, vfrom1 :: rest1, body) case (gen: GenFrom) :: test :: rest => - val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test)) - val genFrom = - if (isIrrefutableGenFrom(gen)) new IrrefutableGenFrom(gen.pat, filtered) - else GenFrom(gen.pat, filtered) + val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) + val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore) makeFor(mapName, flatMapName, genFrom :: rest, body) case _ => EmptyTree //may happen for erroneous input @@ -1571,7 +1599,4 @@ object desugar { collect(tree) buf.toList } - - private class IrrefutableGenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) - extends GenFrom(pat, expr) } diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 94ca12203539..d8a2bdac7362 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -99,7 +99,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case class DoWhile(body: Tree, cond: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree case class ForYield(enums: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree case class ForDo(enums: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree - case class GenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree + case class GenFrom(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit @constructorOnly src: SourceFile) extends Tree case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree @@ -116,6 +116,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { * `Positioned#checkPos` */ class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr) + /** An enum to control checking or filtering of patterns in GenFrom trees */ + class GenCheckMode(val x: Int) extends AnyVal + object GenCheckMode { + val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before + val Check = new GenCheckMode(1) // check that pattern is irrefutable + val FilterNow = new GenCheckMode(2) // filter out non-matching elements since we are not in -strict + val FilterAlways = new GenCheckMode(3) // filter out non-matching elements since pattern is prefixed by `case` + } + // ----- Modifiers ----------------------------------------------------- /** Mod is intended to record syntactic information about modifiers, it's * NOT a replacement of FlagSet. @@ -525,9 +534,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case tree: ForDo if (enums eq tree.enums) && (body eq tree.body) => tree case _ => finalize(tree, untpd.ForDo(enums, body)(tree.source)) } - def GenFrom(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match { - case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) => tree - case _ => finalize(tree, untpd.GenFrom(pat, expr)(tree.source)) + def GenFrom(tree: Tree)(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit ctx: Context): Tree = tree match { + case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (checkMode == tree.checkMode) => tree + case _ => finalize(tree, untpd.GenFrom(pat, expr, checkMode)(tree.source)) } def GenAlias(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match { case tree: GenAlias if (pat eq tree.pat) && (expr eq tree.expr) => tree @@ -589,8 +598,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { cpy.ForYield(tree)(transform(enums), transform(expr)) case ForDo(enums, body) => cpy.ForDo(tree)(transform(enums), transform(body)) - case GenFrom(pat, expr) => - cpy.GenFrom(tree)(transform(pat), transform(expr)) + case GenFrom(pat, expr, checkMode) => + cpy.GenFrom(tree)(transform(pat), transform(expr), checkMode) case GenAlias(pat, expr) => cpy.GenAlias(tree)(transform(pat), transform(expr)) case ContextBounds(bounds, cxBounds) => @@ -644,7 +653,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this(this(x, enums), expr) case ForDo(enums, body) => this(this(x, enums), body) - case GenFrom(pat, expr) => + case GenFrom(pat, expr, _) => this(this(x, pat), expr) case GenAlias(pat, expr) => this(this(x, pat), expr) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 9d16d4077816..622397a20928 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1725,18 +1725,28 @@ object Parsers { */ def enumerator(): Tree = if (in.token == IF) guard() + else if (in.token == CASE) generator() else { val pat = pattern1() if (in.token == EQUALS) atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, expr()) } - else generatorRest(pat) + else generatorRest(pat, casePat = false) } - /** Generator ::= Pattern `<-' Expr + /** Generator ::= [‘case’] Pattern `<-' Expr */ - def generator(): Tree = generatorRest(pattern1()) + def generator(): Tree = { + val casePat = if (in.token == CASE) { in.skipCASE(); true } else false + generatorRest(pattern1(), casePat) + } - def generatorRest(pat: Tree): GenFrom = - atSpan(startOffset(pat), accept(LARROW)) { GenFrom(pat, expr()) } + def generatorRest(pat: Tree, casePat: Boolean): GenFrom = + atSpan(startOffset(pat), accept(LARROW)) { + val checkMode = + if (casePat) GenCheckMode.FilterAlways + else if (ctx.settings.strict.value) GenCheckMode.Check + else GenCheckMode.FilterNow // filter for now, to keep backwards compat + GenFrom(pat, expr(), checkMode) + } /** ForExpr ::= `for' (`(' Enumerators `)' | `{' Enumerators `}') * {nl} [`yield'] Expr @@ -1749,16 +1759,20 @@ object Parsers { else if (in.token == LPAREN) { val lparenOffset = in.skipToken() openParens.change(LPAREN, 1) - val pats = patternsOpt() - val pat = - if (in.token == RPAREN || pats.length > 1) { - wrappedEnums = false - accept(RPAREN) - openParens.change(LPAREN, -1) - atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer. + val res = + if (in.token == CASE) enumerators() + else { + val pats = patternsOpt() + val pat = + if (in.token == RPAREN || pats.length > 1) { + wrappedEnums = false + accept(RPAREN) + openParens.change(LPAREN, -1) + atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer. + } + else pats.head + generatorRest(pat, casePat = false) :: enumeratorsRest() } - else pats.head - val res = generatorRest(pat) :: enumeratorsRest() if (wrappedEnums) { accept(RPAREN) openParens.change(LPAREN, -1) @@ -2640,11 +2654,7 @@ object Parsers { */ def enumCase(start: Offset, mods: Modifiers): DefTree = { val mods1 = addMod(mods, atSpan(in.offset)(Mod.Enum())) | Case - accept(CASE) - - in.adjustSepRegions(ARROW) - // Scanner thinks it is in a pattern match after seeing the `case`. - // We need to get it out of that mode by telling it we are past the `=>` + in.skipCASE() atSpan(start, nameStart) { val id = termIdent() diff --git a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala index 895bf25fb03d..af0fb9c46190 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala @@ -351,6 +351,16 @@ object Scanners { case _ => } + /** Advance beyond a case token without marking the CASE in sepRegions. + * This method should be called to skip beyond CASE tokens that are + * not part of matches, i.e. no ARROW is expected after them. + */ + def skipCASE() = { + assert(token == CASE) + nextToken() + sepRegions = sepRegions.tail + } + /** Produce next token, filling TokenData fields of Scanner. */ def nextToken(): Unit = { diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index c6ca2d857c79..1372b2588cbe 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -570,7 +570,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { forText(enums, expr, keywordStr(" yield ")) case ForDo(enums, expr) => forText(enums, expr, keywordStr(" do ")) - case GenFrom(pat, expr) => + case GenFrom(pat, expr, checkMode) => + (Str("case ") provided checkMode == untpd.GenCheckMode.FilterAlways) ~ toText(pat) ~ " <- " ~ toText(expr) case GenAlias(pat, expr) => toText(pat) ~ " = " ~ toText(expr) diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index 3c4cd6194de4..b07df0230df7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -602,41 +602,48 @@ trait Checking { * This means `pat` is either marked @unchecked or `pt` conforms to the * pattern's type. If pattern is an UnApply, do the check recursively. */ - def checkIrrefutable(pat: Tree, pt: Type)(implicit ctx: Context): Boolean = { - patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt") + def checkIrrefutable(pat: Tree, pt: Type, isPatDef: Boolean)(implicit ctx: Context): Boolean = { def fail(pat: Tree, pt: Type): Boolean = { + var reportedPt = pt.dropAnnot(defn.UncheckedAnnot) + if (!pat.tpe.isSingleton) reportedPt = reportedPt.widen + val problem = if (pat.tpe <:< reportedPt) "is more specialized than" else "does not match" + val fix = if (isPatDef) "`: @unchecked` after" else "`case ` before" ctx.errorOrMigrationWarning( - ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type ${pt.dropAnnot(defn.UncheckedAnnot)} + ex"""pattern's type ${pat.tpe} $problem the right hand side expression's type $reportedPt | - |If the narrowing is intentional, this can be communicated by writing `: @unchecked` after the full pattern.${err.rewriteNotice}""", + |If the narrowing is intentional, this can be communicated by writing $fix the full pattern.${err.rewriteNotice}""", pat.sourcePos) false } def check(pat: Tree, pt: Type): Boolean = (pt <:< pat.tpe) || fail(pat, pt) - !ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR - pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || { - pat match { - case Bind(_, pat1) => - checkIrrefutable(pat1, pt) - case UnApply(fn, _, pats) => - check(pat, pt) && - (isIrrefutableUnapply(fn) || fail(pat, pt)) && { - val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.sourcePos) - pats.corresponds(argPts)(checkIrrefutable) - } - case Alternative(pats) => - pats.forall(checkIrrefutable(_, pt)) - case Typed(arg, tpt) => - check(pat, pt) && checkIrrefutable(arg, pt) - case Ident(nme.WILDCARD) => - true - case _ => - check(pat, pt) + def recur(pat: Tree, pt: Type): Boolean = + !ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR + pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || { + patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt") + pat match { + case Bind(_, pat1) => + recur(pat1, pt) + case UnApply(fn, _, pats) => + check(pat, pt) && + (isIrrefutableUnapply(fn) || fail(pat, pt)) && { + val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.sourcePos) + pats.corresponds(argPts)(recur) + } + case Alternative(pats) => + pats.forall(recur(_, pt)) + case Typed(arg, tpt) => + check(pat, pt) && recur(arg, pt) + case Ident(nme.WILDCARD) => + true + case _ => + check(pat, pt) + } } - } + + recur(pat, pt) } /** Check that `path` is a legal prefix for an import or export clause */ diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 38cfa7da2d11..5b7d758da246 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1029,8 +1029,10 @@ class Typer extends Namer } else { val (protoFormals, _) = decomposeProtoFunction(pt, 1) - val unchecked = pt.isRef(defn.PartialFunctionClass) - typed(desugar.makeCaseLambda(tree.cases, protoFormals.length, unchecked).withSpan(tree.span), pt) + val checkMode = + if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None + else desugar.MatchCheck.Exhaustive + typed(desugar.makeCaseLambda(tree.cases, checkMode, protoFormals.length).withSpan(tree.span), pt) } case _ => if (tree.isInline) checkInInlineContext("inline match", tree.posd) @@ -1038,10 +1040,15 @@ class Typer extends Namer val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt) result match { - case Match(sel, CaseDef(pat, _, _) :: _) - if (tree.selector.removeAttachment(desugar.PatDefMatch).isDefined) => - if (!checkIrrefutable(pat, sel.tpe) && ctx.scala2Mode) - patch(Span(pat.span.end), ": @unchecked") + case Match(sel, CaseDef(pat, _, _) :: _) => + tree.selector.removeAttachment(desugar.CheckIrrefutable) match { + case Some(checkMode) => + val isPatDef = checkMode == desugar.MatchCheck.IrrefutablePatDef + if (!checkIrrefutable(pat, sel.tpe, isPatDef) && ctx.settings.migration.value) + if (isPatDef) patch(Span(pat.span.end), ": @unchecked") + else patch(Span(pat.span.start), "case ") + case _ => + } case _ => } result diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 9ab3cb81be24..c7a18e8fb9b5 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -194,6 +194,7 @@ class CompilationTests extends ParallelTesting { compileFilesInDir("tests/run-custom-args/Yretain-trees", defaultOptions and "-Yretain-trees"), compileFile("tests/run-custom-args/tuple-cons.scala", allowDeepSubtypes), compileFile("tests/run-custom-args/i5256.scala", allowDeepSubtypes), + compileFile("tests/run-custom-args/fors.scala", defaultOptions and "-strict"), compileFile("tests/run-custom-args/no-useless-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"), compileFilesInDir("tests/run", defaultOptions) ).checkRuns() diff --git a/compiler/test/dotty/tools/dotc/parsing/parsePackage.scala b/compiler/test/dotty/tools/dotc/parsing/parsePackage.scala index bdf0304b8fb4..a72016e82bfe 100644 --- a/compiler/test/dotty/tools/dotc/parsing/parsePackage.scala +++ b/compiler/test/dotty/tools/dotc/parsing/parsePackage.scala @@ -49,8 +49,8 @@ object parsePackage extends ParserTest { ForYield(enums map transform, transform(expr)) case ForDo(enums, expr) => ForDo(enums map transform, transform(expr)) - case GenFrom(pat, expr) => - GenFrom(transform(pat), transform(expr)) + case GenFrom(pat, expr, filtering) => + GenFrom(transform(pat), transform(expr), filtering) case GenAlias(pat, expr) => GenAlias(transform(pat), transform(expr)) case PatDef(mods, pats, tpt, expr) => diff --git a/docs/docs/internals/syntax.md b/docs/docs/internals/syntax.md index 7c071a2c9169..9594b63ba5ab 100644 --- a/docs/docs/internals/syntax.md +++ b/docs/docs/internals/syntax.md @@ -245,7 +245,7 @@ Enumerators ::= Generator {semi Enumerator | Guard} Enumerator ::= Generator | Guard | Pattern1 ‘=’ Expr GenAlias(pat, expr) -Generator ::= Pattern1 ‘<-’ Expr GenFrom(pat, expr) +Generator ::= [‘case’] Pattern1 ‘<-’ Expr GenFrom(pat, expr) Guard ::= ‘if’ PostfixExpr CaseClauses ::= CaseClause { CaseClause } Match(EmptyTree, cases) diff --git a/docs/docs/reference/changed-features/pattern-bindings.md b/docs/docs/reference/changed-features/pattern-bindings.md new file mode 100644 index 000000000000..6bf0aeb3379a --- /dev/null +++ b/docs/docs/reference/changed-features/pattern-bindings.md @@ -0,0 +1,64 @@ +--- +layout: doc-page +title: "Pattern Bindings" +--- + +In Scala 2, pattern bindings in `val` definitions and `for` expressions are +loosely typed. Potentially failing matches are still accepted at compile-time, +but may influence the program's runtime behavior. +From Scala 3.1 on, type checking rules will be tightened so that errors are reported at compile-time instead. + +## Bindings in Pattern Definitions + +```scala + val xs: List[Any] = List(1, 2, 3) + val (x: String) :: _ = xs // error: pattern's type String is more specialized + // than the right hand side expression's type Any +``` +This code gives a compile-time error in Scala 3.1 (and also in Scala 3.0 under the `-strict` setting) whereas it will fail at runtime with a `ClassCastException` in Scala 2. In Scala 3.1, a pattern binding is only allowed if the pattern is _irrefutable_, that is, if the right-hand side's type conforms to the pattern's type. For instance, the following is OK: +```scala + val pair = (1, true) + val (x, y) = pair +``` +Sometimes one wants to decompose data anyway, even though the pattern is refutable. For instance, if at some point one knows that a list `elems` is non-empty one might +want to decompose it like this: +```scala + val first :: rest = elems // error +``` +This works in Scala 2. In fact it is a typical use case for Scala 2's rules. But in Scala 3.1 it will give a type error. One can avoid the error by marking the pattern with an @unchecked annotation: +```scala + val first :: rest : @unchecked = elems // OK +``` +This will make the compiler accept the pattern binding. It might give an error at runtime instead, if the underlying assumption that `elems` can never be empty is wrong. + +## Pattern Bindings in For Expressions + +Analogous changes apply to patterns in `for` expressions. For instance: + +```scala + val elems: List[Any] = List((1, 2), "hello", (3, 4)) + for ((x, y) <- elems) yield (y, x) // error: pattern's type (Any, Any) is more specialized + // than the right hand side expression's type Any +``` +This code gives a compile-time error in Scala 3.1 whereas in Scala 2 the list `elems` +is filtered to retain only the elements of tuple type that match the pattern `(x, y)`. +The filtering functionality can be obtained in Scala 3 by prefixing the pattern with `case`: +```scala + for (case (x, y) <- elems) yield (y, x) // returns List((2, 1), (4, 3)) +``` + +## Syntax Changes + +There are two syntax changes relative to Scala 2: First, pattern definitions can carry ascriptions such as `: @unchecked`. Second, generators in for expressions may be prefixed with `case`. +``` + PatDef ::= ids [‘:’ Type] ‘=’ Expr + | Pattern2 [‘:’ Type | Ascription] ‘=’ Expr + Generator ::= [‘case’] Pattern1 ‘<-’ Expr +``` + +## Migration + +The new syntax is supported in Dotty and Scala 3.0. However, to enable smooth cross compilation between Scala 2 and Scala 3, the changed behavior and additional type checks are only enabled under the `-strict` setting. They will be enabled by default in version 3.1 of the language. + + + diff --git a/docs/sidebar.yml b/docs/sidebar.yml index 5b768b37dc26..9d6b4a33df06 100644 --- a/docs/sidebar.yml +++ b/docs/sidebar.yml @@ -107,7 +107,9 @@ sidebar: url: docs/reference/changed-features/overload-resolution.html - title: Vararg Patterns url: docs/reference/changed-features/vararg-patterns.html - - title: Pattern matching + - title: Pattern Bindings + url: docs/reference/changed-features/pattern-bindings.html + - title: Pattern Matching url: docs/reference/changed-features/pattern-matching.html - title: Eta Expansion url: docs/reference/changed-features/eta-expansion.html diff --git a/tests/neg-strict/filtering-fors.scala b/tests/neg-strict/filtering-fors.scala new file mode 100644 index 000000000000..784e354c53ff --- /dev/null +++ b/tests/neg-strict/filtering-fors.scala @@ -0,0 +1,32 @@ +object Test { + + val xs: List[Any] = ??? + + for (x <- xs) do () // OK + for (x: Any <- xs) do () // OK + + for (x: String <- xs) do () // error + for ((x: String) <- xs) do () // error + for (y@ (x: String) <- xs) do () // error + for ((x, y) <- xs) do () // error + + for ((x: String) <- xs if x.isEmpty) do () // error + for ((x: String) <- xs; y = x) do () // error + for ((x: String) <- xs; (y, z) <- xs) do () // error // error + for (case (x: String) <- xs; (y, z) <- xs) do () // error + for ((x: String) <- xs; case (y, z) <- xs) do () // error + + val pairs: List[Any] = List((1, 2), "hello", (3, 4)) + for ((x, y) <- pairs) yield (y, x) // error + + for (case x: String <- xs) do () // OK + for (case (x: String) <- xs) do () // OK + for (case y@ (x: String) <- xs) do () // OK + for (case (x, y) <- xs) do () // OK + + for (case (x: String) <- xs if x.isEmpty) do () // OK + for (case (x: String) <- xs; y = x) do () // OK + for (case (x: String) <- xs; case (y, z) <- xs) do () // OK + + for (case (x, y) <- pairs) yield (y, x) // OK +} \ No newline at end of file diff --git a/tests/neg-strict/unchecked-patterns.scala b/tests/neg-strict/unchecked-patterns.scala index d6a8cb70cc2a..1adcac67e420 100644 --- a/tests/neg-strict/unchecked-patterns.scala +++ b/tests/neg-strict/unchecked-patterns.scala @@ -9,6 +9,8 @@ object Test { val (_: Int | _: Any) = ??? : Any // error + val 1 = 2 // error + object Positive { def unapply(i: Int): Option[Int] = Some(i).filter(_ > 0) } object Always1 { def unapply(i: Int): Some[Int] = Some(i) } object Pair { def unapply(t: (Int, Int)): t.type = t } diff --git a/tests/neg/zipped.scala b/tests/neg/zipped.scala new file mode 100644 index 000000000000..feef1f824bf0 --- /dev/null +++ b/tests/neg/zipped.scala @@ -0,0 +1,38 @@ +// This test shows some un-intuitive behavior of the `zipped` method. +object Test { + val xs: List[Int] = ??? + + // 1. This works, since withFilter is not defined on Tuple3zipped. Instead, + // an implicit conversion from Tuple3zipped to Traversable[(Int, Int, Int)] is inserted. + // The subsequent map operation has the right type for this Traversable. + (xs, xs, xs).zipped + .withFilter( (x: (Int, Int, Int)) => x match { case (x, y, z) => true } ) // OK + .map( (x: (Int, Int, Int)) => x match { case (x, y, z) => x + y + z }) // OK + + + // 2. This works as well, because of auto untupling i.e. `case` is inserted. + // But it does not work in Scala2. + (xs, xs, xs).zipped + .withFilter( (x: (Int, Int, Int)) => x match { case (x, y, z) => true } ) // OK + .map( (x: Int, y: Int, z: Int) => x + y + z ) // OK + // works, because of auto untupling i.e. `case` is inserted + // does not work in Scala2 + + // 3. Now, without withFilter, it's the opposite, we need the 3 parameter map. + (xs, xs, xs).zipped + .map( (x: Int, y: Int, z: Int) => x + y + z ) // OK + + // 4. The single parameter map does not work. + (xs, xs, xs).zipped + .map( (x: (Int, Int, Int)) => x match { case (x, y, z) => x + y + z }) // error + + // 5. If we leave out the parameter type, we get a "Wrong number of parameters" error instead + (xs, xs, xs).zipped + .map( x => x match { case (x, y, z) => x + y + z }) // error + + // This means that the following works in Dotty in normal mode, since a `withFilter` + // is inserted. But it does no work under -strict. And it will not work in Scala 3.1. + // The reason is that without -strict, the code below is mapped to (1), but with -strict + // it is mapped to (5). + for ((x, y, z) <- (xs, xs, xs).zipped) yield x + y + z +} \ No newline at end of file diff --git a/tests/pos/derives-obj.scala b/tests/pos/derives-obj.scala new file mode 100644 index 000000000000..269d3d36452c --- /dev/null +++ b/tests/pos/derives-obj.scala @@ -0,0 +1,4 @@ +class C[T] +object C { def derived[T]: C[T] = ??? } + +object X extends C[X.type] derives C diff --git a/tests/pos/multi-given.scala b/tests/pos/multi-given.scala new file mode 100644 index 000000000000..0f7b2523b7a4 --- /dev/null +++ b/tests/pos/multi-given.scala @@ -0,0 +1,6 @@ +trait A +trait B +trait C + +def fancy given (a: A, b: B, c: C) = "Fancy!" +def foo(implicit a: A, b: B, c: C) = "foo" diff --git a/tests/run-custom-args/fors.check b/tests/run-custom-args/fors.check new file mode 100644 index 000000000000..a8c7dfd6c4bd --- /dev/null +++ b/tests/run-custom-args/fors.check @@ -0,0 +1,46 @@ + +testOld +1 2 3 +2 +2 +3 +1 2 3 +1 2 3 +0 1 2 3 4 5 6 7 8 9 +0 2 4 6 8 +0 2 4 6 8 +a b c +b c +b c + +testNew +3 +1 2 3 +1 2 3 +0 1 2 3 4 5 6 7 8 9 +0 2 4 6 8 +0 2 4 6 8 +0 2 4 6 8 +0 2 4 6 8 +0 2 4 6 8 +0 2 4 6 8 +0 2 4 6 8 +a b c + +testFiltering +hello world +hello world +hello world +1~2 3~4 +(empty) +hello world +hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4 +(2,1) (4,3) +hello world +hello world +hello world +1~2 3~4 +(empty) +hello world +hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4 +(2,1) (4,3) diff --git a/tests/run-custom-args/fors.scala b/tests/run-custom-args/fors.scala new file mode 100644 index 000000000000..2baa78cbc369 --- /dev/null +++ b/tests/run-custom-args/fors.scala @@ -0,0 +1,117 @@ +//############################################################################ +// for-comprehensions (old and new syntax) +//############################################################################ + +//############################################################################ + +object Test extends dotty.runtime.LegacyApp { + val xs = List(1, 2, 3) + val ys = List(Symbol("a"), Symbol("b"), Symbol("c")) + + def it = 0 until 10 + + val ar = "abc".toCharArray + + /////////////////// old syntax /////////////////// + + def testOld(): Unit = { + println("\ntestOld") + + // lists + for (x <- xs) print(x + " "); println() + for (x <- xs; + if x % 2 == 0) print(x + " "); println() + for {x <- xs + if x % 2 == 0} print(x + " "); println() + var n = 0 + for (_ <- xs) n += 1; println(n) + for ((x, y) <- xs zip ys) print(x + " "); println() + for (p @ (x, y) <- xs zip ys) print(p._1 + " "); println() + + // iterators + for (x <- it) print(x + " "); println() + for (x <- it; + if x % 2 == 0) print(x + " "); println() + for {x <- it + if x % 2 == 0} print(x + " "); println() + + // arrays + for (x <- ar) print(x + " "); println() + for (x <- ar; + if x.toInt > 97) print(x + " "); println() + for {x <- ar + if x.toInt > 97} print(x + " "); println() + + } + + /////////////////// new syntax /////////////////// + + def testNew(): Unit = { + println("\ntestNew") + + // lists + var n = 0 + for (_ <- xs) n += 1; println(n) + for ((x, y) <- xs zip ys) print(x + " "); println() + for (p @ (x, y) <- xs zip ys) print(p._1 + " "); println() + + // iterators + for (x <- it) print(x + " "); println() + for (x <- it if x % 2 == 0) print(x + " "); println() + for (x <- it; if x % 2 == 0) print(x + " "); println() + for (x <- it; + if x % 2 == 0) print(x + " "); println() + for (x <- it + if x % 2 == 0) print(x + " "); println() + for {x <- it + if x % 2 == 0} print(x + " "); println() + for (x <- it; + y = 2 + if x % y == 0) print(x + " "); println() + for {x <- it + y = 2 + if x % y == 0} print(x + " "); println() + + // arrays + for (x <- ar) print(x + " "); println() + + } + + /////////////////// filtering with case /////////////////// + + def testFiltering(): Unit = { + println("\ntestFiltering") + + val xs: List[Any] = List((1, 2), "hello", (3, 4), "", "world") + + for (case x: String <- xs) do print(s"$x "); println() + for (case (x: String) <- xs) do print(s"$x "); println() + for (case y@ (x: String) <- xs) do print(s"$y "); println() + + for (case (x, y) <- xs) do print(s"$x~$y "); println() + + for (case (x: String) <- xs if x.isEmpty) do print("(empty)"); println() + for (case (x: String) <- xs; y = x) do print(s"$y "); println() + for (case (x: String) <- xs; case (y, z) <- xs) do print(s"$x/$y~$z "); println() + + for (case (x, y) <- xs) do print(s"${(y, x)} "); println() + + for case x: String <- xs do print(s"$x "); println() + for case (x: String) <- xs do print(s"$x "); println() + for case y@ (x: String) <- xs do print(s"$y "); println() + + for case (x, y) <- xs do print(s"$x~$y "); println() + + for case (x: String) <- xs if x.isEmpty do print("(empty)"); println() + for case (x: String) <- xs; y = x do print(s"$y "); println() + for case (x: String) <- xs; case (y, z) <- xs do print(s"$x/$y~$z "); println() + + for case (x, y) <- xs do print(s"${(y, x)} "); println() + } + + //////////////////////////////////////////////////// + + testOld() + testNew() + testFiltering() +}