@@ -1234,7 +1234,55 @@ class Typer extends Namer
1234
1234
if (tree.isInline) checkInInlineContext(" inline match" , tree.posd)
1235
1235
val sel1 = typedExpr(tree.selector)
1236
1236
val selType = fullyDefinedType(sel1.tpe, " pattern selector" , tree.span).widen
1237
- val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1237
+
1238
+ /** Extractor for match types hidden behind an AppliedType/MatchAlias */
1239
+ object MatchTypeInDisguise {
1240
+ def unapply (tp : AppliedType ): Option [MatchType ] = tp match {
1241
+ case AppliedType (tycon : TypeRef , args) =>
1242
+ tycon.info match {
1243
+ case MatchAlias (alias) =>
1244
+ alias.applyIfParameterized(args) match {
1245
+ case mt : MatchType => Some (mt)
1246
+ case _ => None
1247
+ }
1248
+ case _ => None
1249
+ }
1250
+ case _ => None
1251
+ }
1252
+ }
1253
+
1254
+ /** Does `tree` has the same shape as the given match type?
1255
+ * We only support typed patterns with empty guards, but
1256
+ * that could potentially be extended in the future.
1257
+ */
1258
+ def isMatchTypeShaped (mt : MatchType ): Boolean =
1259
+ mt.cases.size == tree.cases.size
1260
+ && sel1.tpe.frozen_<:< (mt.scrutinee)
1261
+ && tree.cases.forall(_.guard.isEmpty)
1262
+ && tree.cases
1263
+ .map(cas => untpd.unbind(untpd.unsplice(cas.pat)))
1264
+ .zip(mt.cases)
1265
+ .forall {
1266
+ case (pat : Typed , pt) =>
1267
+ // To check that pattern types correspond we need to type
1268
+ // check `pat` here and throw away the result.
1269
+ val gadtCtx : Context = ctx.fresh.setFreshGADTBounds
1270
+ val pat1 = typedPattern(pat, selType)(using gadtCtx)
1271
+ val Typed (_, tpt) = tpd.unbind(tpd.unsplice(pat1))
1272
+ instantiateMatchTypeProto(pat1, pt) match {
1273
+ case defn.MatchCase (patternTp, _) => tpt.tpe frozen_=:= patternTp
1274
+ case _ => false
1275
+ }
1276
+ case _ => false
1277
+ }
1278
+
1279
+ val result = pt match {
1280
+ case MatchTypeInDisguise (mt) if isMatchTypeShaped(mt) =>
1281
+ typedDependentMatchFinish(tree, sel1, selType, tree.cases, mt)
1282
+ case _ =>
1283
+ typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1284
+ }
1285
+
1238
1286
result match {
1239
1287
case Match (sel, CaseDef (pat, _, _) :: _) =>
1240
1288
tree.selector.removeAttachment(desugar.CheckIrrefutable ) match {
@@ -1250,6 +1298,21 @@ class Typer extends Namer
1250
1298
result
1251
1299
}
1252
1300
1301
+ /** Special typing of Match tree when the expected type is a MatchType,
1302
+ * and the patterns of the Match tree and the MatchType correspond.
1303
+ */
1304
+ def typedDependentMatchFinish (tree : untpd.Match , sel : Tree , wideSelType : Type , cases : List [untpd.CaseDef ], pt : MatchType )(using Context ): Tree = {
1305
+ var caseCtx = ctx
1306
+ val cases1 = tree.cases.zip(pt.cases)
1307
+ .map { case (cas, tpe) =>
1308
+ val case1 = typedCase(cas, sel, wideSelType, tpe)(using caseCtx)
1309
+ caseCtx = Nullables .afterPatternContext(sel, case1.pat)
1310
+ case1
1311
+ }
1312
+ .asInstanceOf [List [CaseDef ]]
1313
+ assignType(cpy.Match (tree)(sel, cases1), sel, cases1).cast(pt)
1314
+ }
1315
+
1253
1316
// Overridden in InlineTyper for inline matches
1254
1317
def typedMatchFinish (tree : untpd.Match , sel : Tree , wideSelType : Type , cases : List [untpd.CaseDef ], pt : Type )(using Context ): Tree = {
1255
1318
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
@@ -1290,17 +1353,33 @@ class Typer extends Namer
1290
1353
}
1291
1354
}
1292
1355
1356
+ /** If the prototype `pt` is the type lambda (when doing a dependent
1357
+ * typing of a match), instantiate that type lambda with the pattern
1358
+ * variables found in the pattern `pat`.
1359
+ */
1360
+ def instantiateMatchTypeProto (pat : Tree , pt : Type )(implicit ctx : Context ) = pt match {
1361
+ case caseTp : HKTypeLambda =>
1362
+ val bindingsSyms = tpd.patVars(pat).reverse
1363
+ val bindingsTps = bindingsSyms.collect { case sym if sym.isType => sym.typeRef }
1364
+ caseTp.appliedTo(bindingsTps)
1365
+ case pt => pt
1366
+ }
1367
+
1293
1368
/** Type a case. */
1294
1369
def typedCase (tree : untpd.CaseDef , sel : Tree , wideSelType : Type , pt : Type )(using Context ): CaseDef = {
1295
1370
val originalCtx = ctx
1296
1371
val gadtCtx : Context = ctx.fresh.setFreshGADTBounds
1297
1372
1298
1373
def caseRest (pat : Tree )(using Context ) = {
1374
+ val pt1 = instantiateMatchTypeProto(pat, pt) match {
1375
+ case defn.MatchCase (_, bodyPt) => bodyPt
1376
+ case pt => pt
1377
+ }
1299
1378
val pat1 = indexPattern(tree).transform(pat)
1300
1379
val guard1 = typedExpr(tree.guard, defn.BooleanType )
1301
- var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt ), pt , ctx.scope.toList)
1302
- if (pt .isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1303
- body1 = body1.ensureConforms(pt )(originalCtx)
1380
+ var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1 ), pt1 , ctx.scope.toList)
1381
+ if (pt1 .isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1382
+ body1 = body1.ensureConforms(pt1 )(originalCtx)
1304
1383
assignType(cpy.CaseDef (tree)(pat1, guard1, body1), pat1, body1)
1305
1384
}
1306
1385
0 commit comments