Skip to content

Commit f64e879

Browse files
Merge pull request #8024 from dotty-staging/fix-6709
Fix #6709: Correlate match types and match expression
2 parents a6b47e4 + 60cfa2d commit f64e879

File tree

5 files changed

+293
-135
lines changed

5 files changed

+293
-135
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,55 @@ class Typer extends Namer
12341234
if (tree.isInline) checkInInlineContext("inline match", tree.posd)
12351235
val sel1 = typedExpr(tree.selector)
12361236
val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen
1237-
val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1237+
1238+
/** Extractor for match types hidden behind an AppliedType/MatchAlias */
1239+
object MatchTypeInDisguise {
1240+
def unapply(tp: AppliedType): Option[MatchType] = tp match {
1241+
case AppliedType(tycon: TypeRef, args) =>
1242+
tycon.info match {
1243+
case MatchAlias(alias) =>
1244+
alias.applyIfParameterized(args) match {
1245+
case mt: MatchType => Some(mt)
1246+
case _ => None
1247+
}
1248+
case _ => None
1249+
}
1250+
case _ => None
1251+
}
1252+
}
1253+
1254+
/** Does `tree` has the same shape as the given match type?
1255+
* We only support typed patterns with empty guards, but
1256+
* that could potentially be extended in the future.
1257+
*/
1258+
def isMatchTypeShaped(mt: MatchType): Boolean =
1259+
mt.cases.size == tree.cases.size
1260+
&& sel1.tpe.frozen_<:<(mt.scrutinee)
1261+
&& tree.cases.forall(_.guard.isEmpty)
1262+
&& tree.cases
1263+
.map(cas => untpd.unbind(untpd.unsplice(cas.pat)))
1264+
.zip(mt.cases)
1265+
.forall {
1266+
case (pat: Typed, pt) =>
1267+
// To check that pattern types correspond we need to type
1268+
// check `pat` here and throw away the result.
1269+
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
1270+
val pat1 = typedPattern(pat, selType)(using gadtCtx)
1271+
val Typed(_, tpt) = tpd.unbind(tpd.unsplice(pat1))
1272+
instantiateMatchTypeProto(pat1, pt) match {
1273+
case defn.MatchCase(patternTp, _) => tpt.tpe frozen_=:= patternTp
1274+
case _ => false
1275+
}
1276+
case _ => false
1277+
}
1278+
1279+
val result = pt match {
1280+
case MatchTypeInDisguise(mt) if isMatchTypeShaped(mt) =>
1281+
typedDependentMatchFinish(tree, sel1, selType, tree.cases, mt)
1282+
case _ =>
1283+
typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1284+
}
1285+
12381286
result match {
12391287
case Match(sel, CaseDef(pat, _, _) :: _) =>
12401288
tree.selector.removeAttachment(desugar.CheckIrrefutable) match {
@@ -1250,6 +1298,21 @@ class Typer extends Namer
12501298
result
12511299
}
12521300

1301+
/** Special typing of Match tree when the expected type is a MatchType,
1302+
* and the patterns of the Match tree and the MatchType correspond.
1303+
*/
1304+
def typedDependentMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: MatchType)(using Context): Tree = {
1305+
var caseCtx = ctx
1306+
val cases1 = tree.cases.zip(pt.cases)
1307+
.map { case (cas, tpe) =>
1308+
val case1 = typedCase(cas, sel, wideSelType, tpe)(using caseCtx)
1309+
caseCtx = Nullables.afterPatternContext(sel, case1.pat)
1310+
case1
1311+
}
1312+
.asInstanceOf[List[CaseDef]]
1313+
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt)
1314+
}
1315+
12531316
// Overridden in InlineTyper for inline matches
12541317
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(using Context): Tree = {
12551318
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
@@ -1290,17 +1353,33 @@ class Typer extends Namer
12901353
}
12911354
}
12921355

1356+
/** If the prototype `pt` is the type lambda (when doing a dependent
1357+
* typing of a match), instantiate that type lambda with the pattern
1358+
* variables found in the pattern `pat`.
1359+
*/
1360+
def instantiateMatchTypeProto(pat: Tree, pt: Type)(implicit ctx: Context) = pt match {
1361+
case caseTp: HKTypeLambda =>
1362+
val bindingsSyms = tpd.patVars(pat).reverse
1363+
val bindingsTps = bindingsSyms.collect { case sym if sym.isType => sym.typeRef }
1364+
caseTp.appliedTo(bindingsTps)
1365+
case pt => pt
1366+
}
1367+
12931368
/** Type a case. */
12941369
def typedCase(tree: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(using Context): CaseDef = {
12951370
val originalCtx = ctx
12961371
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
12971372

12981373
def caseRest(pat: Tree)(using Context) = {
1374+
val pt1 = instantiateMatchTypeProto(pat, pt) match {
1375+
case defn.MatchCase(_, bodyPt) => bodyPt
1376+
case pt => pt
1377+
}
12991378
val pat1 = indexPattern(tree).transform(pat)
13001379
val guard1 = typedExpr(tree.guard, defn.BooleanType)
1301-
var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt), pt, ctx.scope.toList)
1302-
if (pt.isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1303-
body1 = body1.ensureConforms(pt)(originalCtx)
1380+
var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1), pt1, ctx.scope.toList)
1381+
if (pt1.isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1382+
body1 = body1.ensureConforms(pt1)(originalCtx)
13041383
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1)
13051384
}
13061385

compiler/test/dotc/pos-test-pickling.blacklist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ matchtype.scala
2626
i7087.scala
2727
i7868.scala
2828
i7872.scala
29+
6709.scala
2930

3031
# Opaque type
3132
i5720.scala

0 commit comments

Comments
 (0)