Skip to content

Commit

Permalink
SIP-56: Better foundations for match types (#18262)
Browse files Browse the repository at this point in the history
Implementation of SIP-56: Proper Specification for Match Types
https://docs.scala-lang.org/sips/match-types-spec.html
  • Loading branch information
sjrd authored Dec 18, 2023
2 parents 8a3fc7a + c3b9d9b commit 881e945
Show file tree
Hide file tree
Showing 60 changed files with 1,670 additions and 290 deletions.
20 changes: 14 additions & 6 deletions compiler/src/dotty/tools/dotc/core/MatchTypeTrace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ object MatchTypeTrace:

private enum TraceEntry:
case TryReduce(scrut: Type)
case Stuck(scrut: Type, stuckCase: Type, otherCases: List[Type])
case NoInstance(scrut: Type, stuckCase: Type, fails: List[(Name, TypeBounds)])
case Stuck(scrut: Type, stuckCase: MatchTypeCaseSpec, otherCases: List[MatchTypeCaseSpec])
case NoInstance(scrut: Type, stuckCase: MatchTypeCaseSpec, fails: List[(Name, TypeBounds)])
case EmptyScrutinee(scrut: Type)
import TraceEntry.*

Expand Down Expand Up @@ -54,10 +54,10 @@ object MatchTypeTrace:
* not disjoint from it either, which means that the remaining cases `otherCases`
* cannot be visited. Only the first failure is recorded.
*/
def stuck(scrut: Type, stuckCase: Type, otherCases: List[Type])(using Context) =
def stuck(scrut: Type, stuckCase: MatchTypeCaseSpec, otherCases: List[MatchTypeCaseSpec])(using Context) =
matchTypeFail(Stuck(scrut, stuckCase, otherCases))

def noInstance(scrut: Type, stuckCase: Type, fails: List[(Name, TypeBounds)])(using Context) =
def noInstance(scrut: Type, stuckCase: MatchTypeCaseSpec, fails: List[(Name, TypeBounds)])(using Context) =
matchTypeFail(NoInstance(scrut, stuckCase, fails))

/** Record a failure that scrutinee `scrut` is provably empty.
Expand All @@ -80,13 +80,16 @@ object MatchTypeTrace:
case _ =>
op

def caseText(spec: MatchTypeCaseSpec)(using Context): String =
caseText(spec.origMatchCase)

def caseText(tp: Type)(using Context): String = tp match
case tp: HKTypeLambda => caseText(tp.resultType)
case defn.MatchCase(any, body) if any eq defn.AnyType => i"case _ => $body"
case defn.MatchCase(pat, body) => i"case $pat => $body"
case _ => i"case $tp"

private def casesText(cases: List[Type])(using Context) =
private def casesText(cases: List[MatchTypeCaseSpec])(using Context) =
i"${cases.map(caseText)}%\n %"

private def explainEntry(entry: TraceEntry)(using Context): String = entry match
Expand Down Expand Up @@ -116,10 +119,15 @@ object MatchTypeTrace:
| ${fails.map((name, bounds) => i"$name$bounds")}%\n %"""

/** The failure message when the scrutinee `scrut` does not match any case in `cases`. */
def noMatchesText(scrut: Type, cases: List[Type])(using Context): String =
def noMatchesText(scrut: Type, cases: List[MatchTypeCaseSpec])(using Context): String =
i"""failed since selector $scrut
|matches none of the cases
|
| ${casesText(cases)}"""

def illegalPatternText(scrut: Type, cas: MatchTypeCaseSpec.LegacyPatMat)(using Context): String =
i"""The match type contains an illegal case:
| ${caseText(cas)}
|(this error can be ignored for now with `-source:3.3`)"""

end MatchTypeTrace
709 changes: 493 additions & 216 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Large diffs are not rendered by default.

190 changes: 188 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Flags.*
import Names.*
import StdNames.*, NameOps.*
import NullOpsDecorator.*
import NameKinds.SkolemName
import NameKinds.{SkolemName, WildcardParamName}
import Scopes.*
import Constants.*
import Contexts.*
Expand All @@ -30,6 +30,8 @@ import Hashable.*
import Uniques.*
import collection.mutable
import config.Config
import config.Feature.sourceVersion
import config.SourceVersion
import annotation.{tailrec, constructorOnly}
import scala.util.hashing.{ MurmurHash3 => hashing }
import config.Printers.{core, typr, matchTypes}
Expand Down Expand Up @@ -5036,7 +5038,7 @@ object Types extends TypeUtils {
trace(i"reduce match type $this $hashCode", matchTypes, show = true)(inMode(Mode.Type) {
def matchCases(cmp: TrackingTypeComparer): Type =
val saved = ctx.typerState.snapshot()
try cmp.matchCases(scrutinee.normalized, cases)
try cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
catch case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
finally
Expand Down Expand Up @@ -5088,6 +5090,190 @@ object Types extends TypeUtils {
case _ => None
}

enum MatchTypeCasePattern:
case Capture(num: Int, isWildcard: Boolean)
case TypeTest(tpe: Type)
case BaseTypeTest(classType: TypeRef, argPatterns: List[MatchTypeCasePattern], needsConcreteScrut: Boolean)
case CompileTimeS(argPattern: MatchTypeCasePattern)
case AbstractTypeConstructor(tycon: Type, argPatterns: List[MatchTypeCasePattern])
case TypeMemberExtractor(typeMemberName: TypeName, capture: Capture)

def isTypeTest: Boolean =
this.isInstanceOf[TypeTest]

def needsConcreteScrutInVariantPos: Boolean = this match
case Capture(_, isWildcard) => !isWildcard
case TypeTest(_) => false
case _ => true
end MatchTypeCasePattern

enum MatchTypeCaseSpec:
case SubTypeTest(origMatchCase: Type, pattern: Type, body: Type)
case SpeccedPatMat(origMatchCase: HKTypeLambda, captureCount: Int, pattern: MatchTypeCasePattern, body: Type)
case LegacyPatMat(origMatchCase: HKTypeLambda)
case MissingCaptures(origMatchCase: HKTypeLambda, missing: collection.BitSet)

def origMatchCase: Type
end MatchTypeCaseSpec

object MatchTypeCaseSpec:
def analyze(cas: Type)(using Context): MatchTypeCaseSpec =
cas match
case cas: HKTypeLambda if !sourceVersion.isAtLeast(SourceVersion.`3.4`) =>
// Always apply the legacy algorithm under -source:3.3 and below
LegacyPatMat(cas)
case cas: HKTypeLambda =>
val defn.MatchCase(pat, body) = cas.resultType: @unchecked
val missing = checkCapturesPresent(cas, pat)
if !missing.isEmpty then
MissingCaptures(cas, missing)
else
val specPattern = tryConvertToSpecPattern(cas, pat)
if specPattern != null then
SpeccedPatMat(cas, cas.paramNames.size, specPattern, body)
else
LegacyPatMat(cas)
case _ =>
val defn.MatchCase(pat, body) = cas: @unchecked
SubTypeTest(cas, pat, body)
end analyze

/** Checks that all the captures of the case are present in the case.
*
* Sometimes, because of earlier substitutions of an abstract type constructor,
* we can end up with patterns that do not mention all their captures anymore.
* This can happen even when the body still refers to these missing captures.
* In that case, we must always consider the case to be unmatchable, i.e., to
* become `Stuck`.
*
* See pos/i12127.scala for an example.
*/
def checkCapturesPresent(cas: HKTypeLambda, pat: Type)(using Context): collection.BitSet =
val captureCount = cas.paramNames.size
val missing = new mutable.BitSet(captureCount)
missing ++= (0 until captureCount)
new CheckCapturesPresent(cas).apply(missing, pat)

private class CheckCapturesPresent(cas: HKTypeLambda)(using Context) extends TypeAccumulator[mutable.BitSet]:
def apply(missing: mutable.BitSet, tp: Type): mutable.BitSet = tp match
case TypeParamRef(binder, num) if binder eq cas =>
missing -= num
case _ =>
foldOver(missing, tp)
end CheckCapturesPresent

/** Tries to convert a match type case pattern in HKTypeLambda form into a spec'ed `MatchTypeCasePattern`.
*
* This method recovers the structure of *legal patterns* as defined in SIP-56
* from the unstructured `HKTypeLambda` coming from the typer.
*
* It must adhere to the specification of legal patterns defined at
* https://docs.scala-lang.org/sips/match-types-spec.html#legal-patterns
*
* Returns `null` if the pattern in `caseLambda` is a not a legal pattern.
*/
private def tryConvertToSpecPattern(caseLambda: HKTypeLambda, pat: Type)(using Context): MatchTypeCasePattern | Null =
var typeParamRefsAccountedFor: Int = 0

def rec(pat: Type, variance: Int): MatchTypeCasePattern | Null =
pat match
case pat @ TypeParamRef(binder, num) if binder eq caseLambda =>
typeParamRefsAccountedFor += 1
MatchTypeCasePattern.Capture(num, isWildcard = pat.paramName.is(WildcardParamName))

case pat @ AppliedType(tycon: TypeRef, args) if variance == 1 =>
val tyconSym = tycon.symbol
if tyconSym.isClass then
if tyconSym.name.startsWith("Tuple") && defn.isTupleNType(pat) then
rec(pat.toNestedPairs, variance)
else
recArgPatterns(pat) { argPatterns =>
val needsConcreteScrut = argPatterns.zip(tycon.typeParams).exists {
(argPattern, tparam) => tparam.paramVarianceSign != 0 && argPattern.needsConcreteScrutInVariantPos
}
MatchTypeCasePattern.BaseTypeTest(tycon, argPatterns, needsConcreteScrut)
}
else if defn.isCompiletime_S(tyconSym) && args.sizeIs == 1 then
val argPattern = rec(args.head, variance)
if argPattern == null then
null
else if argPattern.isTypeTest then
MatchTypeCasePattern.TypeTest(pat)
else
MatchTypeCasePattern.CompileTimeS(argPattern)
else
tycon.info match
case _: RealTypeBounds =>
recAbstractTypeConstructor(pat)
case TypeAlias(tl @ HKTypeLambda(onlyParam :: Nil, resType: RefinedType)) =>
/* Unlike for eta-expanded classes, the typer does not automatically
* dealias poly type aliases to refined types. So we have to give them
* a chance here.
* We are quite specific about the shape of type aliases that we are willing
* to dealias this way, because we must not dealias arbitrary type constructors
* that could refine the bounds of the captures; those would amount of
* type-test + capture combos, which are out of the specced match types.
*/
rec(pat.superType, variance)
case _ =>
null

case pat @ AppliedType(tycon: TypeParamRef, _) if variance == 1 =>
recAbstractTypeConstructor(pat)

case pat @ RefinedType(parent, refinedName: TypeName, TypeAlias(alias @ TypeParamRef(binder, num)))
if variance == 1 && (binder eq caseLambda) =>
parent.member(refinedName) match
case refinedMember: SingleDenotation if refinedMember.exists =>
// Check that the bounds of the capture contain the bounds of the inherited member
val refinedMemberBounds = refinedMember.info
val captureBounds = caseLambda.paramInfos(num)
if captureBounds.contains(refinedMemberBounds) then
/* In this case, we know that any member we eventually find during reduction
* will have bounds that fit in the bounds of the capture. Therefore, no
* type-test + capture combo is necessary, and we can apply the specced match types.
*/
val capture = rec(alias, variance = 0).asInstanceOf[MatchTypeCasePattern.Capture]
MatchTypeCasePattern.TypeMemberExtractor(refinedName, capture)
else
// Otherwise, a type-test + capture combo might be necessary, and we are out of spec
null
case _ =>
// If the member does not refine a member of the `parent`, we are out of spec
null

case _ =>
MatchTypeCasePattern.TypeTest(pat)
end rec

def recAbstractTypeConstructor(pat: AppliedType): MatchTypeCasePattern | Null =
recArgPatterns(pat) { argPatterns =>
MatchTypeCasePattern.AbstractTypeConstructor(pat.tycon, argPatterns)
}
end recAbstractTypeConstructor

def recArgPatterns(pat: AppliedType)(whenNotTypeTest: List[MatchTypeCasePattern] => MatchTypeCasePattern | Null): MatchTypeCasePattern | Null =
val AppliedType(tycon, args) = pat
val tparams = tycon.typeParams
val argPatterns = args.zip(tparams).map { (arg, tparam) =>
rec(arg, tparam.paramVarianceSign)
}
if argPatterns.exists(_ == null) then
null
else
val argPatterns1 = argPatterns.asInstanceOf[List[MatchTypeCasePattern]] // they are not null
if argPatterns1.forall(_.isTypeTest) then
MatchTypeCasePattern.TypeTest(pat)
else
whenNotTypeTest(argPatterns1)
end recArgPatterns

val result = rec(pat, variance = 1)
if typeParamRefsAccountedFor == caseLambda.paramNames.size then result
else null
end tryConvertToSpecPattern
end MatchTypeCaseSpec

// ------ ClassInfo, Type Bounds --------------------------------------------------

type TypeOrSymbol = Type | Symbol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
case VarArgsParamCannotBeGivenID // errorNumber: 188
case ExtractorNotFoundID // errorNumber: 189
case PureUnitExpressionID // errorNumber: 190
case MatchTypeLegacyPatternID // errorNumber: 191

def errorNumber = ordinal - 1

Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3073,6 +3073,10 @@ class MatchTypeScrutineeCannotBeHigherKinded(tp: Type)(using Context)
def msg(using Context) = i"the scrutinee of a match type cannot be higher-kinded"
def explain(using Context) = ""

class MatchTypeLegacyPattern(errorText: String)(using Context) extends TypeMsg(MatchTypeLegacyPatternID):
def msg(using Context) = errorText
def explain(using Context) = ""

class ClosureCannotHaveInternalParameterDependencies(mt: Type)(using Context)
extends TypeMsg(ClosureCannotHaveInternalParameterDependenciesID):
def msg(using Context) =
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ object TypeTestsCasts {

case x =>
// always false test warnings are emitted elsewhere
TypeComparer.provablyDisjoint(x, tpe.derivedAppliedType(tycon, targs.map(_ => WildcardType)))
// provablyDisjoint wants fully applied types as input; because we're in the middle of erasure, we sometimes get raw types here
val xApplied =
val tparams = x.typeParams
if tparams.isEmpty then x else x.appliedTo(tparams.map(_ => WildcardType))
TypeComparer.provablyDisjoint(xApplied, tpe.derivedAppliedType(tycon, targs.map(_ => WildcardType)))
|| typeArgsDeterminable(X, tpe)
||| i"its type arguments can't be determined from $X"
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ trait TypeAssigner {
else fntpe.resultType // fast path optimization
else
errorType(em"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos)
case err: ErrorType =>
err
case t =>
if (ctx.settings.Ydebug.value) new FatalError("").printStackTrace()
errorType(err.takesNoParamsMsg(fn, ""), tree.srcPos)
Expand Down Expand Up @@ -563,5 +565,3 @@ object TypeAssigner extends TypeAssigner:
def seqLitType(tree: untpd.SeqLiteral, elemType: Type)(using Context) = tree match
case tree: untpd.JavaSeqLiteral => defn.ArrayOf(elemType)
case _ => if ctx.erasedTypes then defn.SeqType else defn.SeqType.appliedTo(elemType)


21 changes: 0 additions & 21 deletions tests/neg/12800.scala

This file was deleted.

9 changes: 5 additions & 4 deletions tests/neg/6314-1.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
object G {
final class X
final class Y
trait X
class Y
class Z

trait FooSig {
type Type
Expand All @@ -13,14 +14,14 @@ object G {
type Foo = Foo.Type

type Bar[A] = A match {
case X & Y => String
case X & Z => String
case Y => Int
}

def main(args: Array[String]): Unit = {
val a: Bar[X & Y] = "hello" // error
val i: Bar[Y & Foo] = Foo.apply[Bar](a)
val b: Int = i // error
val b: Int = i
println(b + 1)
}
}
16 changes: 16 additions & 0 deletions tests/neg/6314-6.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Error: tests/neg/6314-6.scala:26:3 ----------------------------------------------------------------------------------
26 | (new YY {}).boom // error: object creation impossible
| ^
|object creation impossible, since def apply(fa: String): Int in trait XX in object Test3 is not defined
|(Note that
| parameter String in def apply(fa: String): Int in trait XX in object Test3 does not match
| parameter Test3.Bar[X & Object with Test3.YY {...}#Foo] in def apply(fa: Test3.Bar[X & YY.this.Foo]): Test3.Bar[Y & YY.this.Foo] in trait YY in object Test3
| )
-- Error: tests/neg/6314-6.scala:52:3 ----------------------------------------------------------------------------------
52 | (new YY {}).boom // error: object creation impossible
| ^
|object creation impossible, since def apply(fa: String): Int in trait XX in object Test4 is not defined
|(Note that
| parameter String in def apply(fa: String): Int in trait XX in object Test4 does not match
| parameter Test4.Bar[X & Object with Test4.YY {...}#FooAlias] in def apply(fa: Test4.Bar[X & YY.this.FooAlias]): Test4.Bar[Y & YY.this.FooAlias] in trait YY in object Test4
| )
Loading

0 comments on commit 881e945

Please sign in to comment.