Skip to content

Commit

Permalink
Rework ProtoType's constrained API
Browse files Browse the repository at this point in the history
Remove the single use overload, replace with a much more used
alternative

Also return TypeVars instead of TypeTrees, so we don't have to unwrap
the useless wrapper a bunch of times, and instead we wrap the few times
we really do want to.
  • Loading branch information
dwijnand committed Sep 29, 2023
1 parent d5a6d4f commit 5b57e09
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 43 deletions.
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ object Trees {
/** A type tree that represents an existing or inferred type */
case class TypeTree[+T <: Untyped]()(implicit @constructorOnly src: SourceFile)
extends DenotingTree[T] with TypTree[T] {
type ThisTree[+T <: Untyped] = TypeTree[T]
type ThisTree[+T <: Untyped] <: TypeTree[T]
override def isEmpty: Boolean = !hasType
override def toString: String =
s"TypeTree${if (hasType) s"[$typeOpt]" else ""}"
Expand All @@ -794,7 +794,8 @@ object Trees {
* - as a (result-)type of an inferred ValDef or DefDef.
* Every TypeVar is created as the type of one InferredTypeTree.
*/
class InferredTypeTree[+T <: Untyped](implicit @constructorOnly src: SourceFile) extends TypeTree[T]
class InferredTypeTree[+T <: Untyped](implicit @constructorOnly src: SourceFile) extends TypeTree[T]:
type ThisTree[+T <: Untyped] <: InferredTypeTree[T]

/** ref.type */
case class SingletonTypeTree[+T <: Untyped] private[ast] (ref: Tree[T])(implicit @constructorOnly src: SourceFile)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3253,7 +3253,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
def matchCase(cas: Type): MatchResult = trace(i"$scrut match ${MatchTypeTrace.caseText(cas)}", matchTypes, show = true) {
val cas1 = cas match {
case cas: HKTypeLambda =>
caseLambda = constrained(cas)
caseLambda = constrained(cas, ast.tpd.EmptyTree)._1
caseLambda.resultType
case _ =>
cas
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4941,6 +4941,9 @@ object Types {
if (inst.exists) inst else origin
}

def wrapInTypeTree(owningTree: Tree)(using Context): InferredTypeTree =
new InferredTypeTree().withSpan(owningTree.span).withType(this)

override def computeHash(bs: Binders): Int = identityHash(bs)
override def equals(that: Any): Boolean = this.eq(that.asInstanceOf[AnyRef])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,9 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
(rawRef, rawInfo)
baseInfo match
case tl: PolyType =>
val (tl1, tpts) = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)
val targs =
for (tpt <- tpts) yield
tpt.tpe match {
case tvar: TypeVar => tvar.instantiate(fromBelow = false)
}
val tvars = constrained(tl)
val targs = for tvar <- tvars yield
tvar.instantiate(fromBelow = false)
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
case methTpe =>
(baseRef, extractParams(methTpe))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ object TypeTestsCasts {
case tp: TypeProxy => underlyingLambda(tp.superType)
}
val typeLambda = underlyingLambda(tycon)
val tvars = constrained(typeLambda, untpd.EmptyTree, alwaysAddTypeVars = true)._2.map(_.tpe)
val tvars = constrained(typeLambda)
val P1 = tycon.appliedTo(tvars)

debug.println("before " + ctx.typerState.constraint.show)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ object SpaceEngine {
val mt: MethodType = unapp.widen match {
case mt: MethodType => mt
case pt: PolyType =>
val tvars = constrained(pt, EmptyTree)._2.tpes
val tvars = constrained(pt)
val mt = pt.instantiate(tvars).asInstanceOf[MethodType]
scrutineeTp <:< mt.paramInfos(0)
// force type inference to infer a narrower type: could be singleton
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ object Inferencing {
def inferTypeParams(tree: Tree, pt: Type)(using Context): Tree = tree.tpe match
case tl: TypeLambda =>
val (tl1, tvars) = constrained(tl, tree)
var tree1 = AppliedTypeTree(tree.withType(tl1), tvars)
val tree1 = AppliedTypeTree(tree.withType(tl1), tvars.map(_.wrapInTypeTree(tree)))
tree1.tpe <:< pt
if isFullyDefined(tree1.tpe, force = ForceDegree.failBottom) then
tree1
Expand Down
30 changes: 14 additions & 16 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -726,41 +726,39 @@ object ProtoTypes {
tl: TypeLambda, owningTree: untpd.Tree,
alwaysAddTypeVars: Boolean,
nestingLevel: Int = ctx.nestingLevel
): (TypeLambda, List[TypeTree]) = {
): (TypeLambda, List[TypeVar]) = {
val state = ctx.typerState
val addTypeVars = alwaysAddTypeVars || !owningTree.isEmpty
if (tl.isInstanceOf[PolyType])
assert(!ctx.typerState.isCommittable || addTypeVars,
s"inconsistent: no typevars were added to committable constraint ${state.constraint}")
// hk type lambdas can be added to constraints without typevars during match reduction

def newTypeVars(tl: TypeLambda): List[TypeTree] =
for (paramRef <- tl.paramRefs)
yield {
val tt = InferredTypeTree().withSpan(owningTree.span)
def newTypeVars(tl: TypeLambda): List[TypeVar] =
for paramRef <- tl.paramRefs
yield
val tvar = TypeVar(paramRef, state, nestingLevel)
state.ownedVars += tvar
tt.withType(tvar)
}
tvar

val added = state.constraint.ensureFresh(tl)
val tvars = if (addTypeVars) newTypeVars(added) else Nil
TypeComparer.addToConstraint(added, tvars.tpes.asInstanceOf[List[TypeVar]])
val tvars = if addTypeVars then newTypeVars(added) else Nil
TypeComparer.addToConstraint(added, tvars)
(added, tvars)
}

def constrained(tl: TypeLambda, owningTree: untpd.Tree)(using Context): (TypeLambda, List[TypeTree]) =
def constrained(tl: TypeLambda, owningTree: untpd.Tree)(using Context): (TypeLambda, List[TypeVar]) =
constrained(tl, owningTree,
alwaysAddTypeVars = tl.isInstanceOf[PolyType] && ctx.typerState.isCommittable)

/** Same as `constrained(tl, EmptyTree)`, but returns just the created type lambda */
def constrained(tl: TypeLambda)(using Context): TypeLambda =
constrained(tl, EmptyTree)._1
/** Same as `constrained(tl, EmptyTree, alwaysAddTypeVars = true)`, but returns just the created type vars. */
def constrained(tl: TypeLambda)(using Context): List[TypeVar] =
constrained(tl, EmptyTree, alwaysAddTypeVars = true)._2

/** Instantiate `tl` with fresh type variables added to the constraint. */
def instantiateWithTypeVars(tl: TypeLambda)(using Context): Type =
val targs = constrained(tl, ast.tpd.EmptyTree, alwaysAddTypeVars = true)._2
tl.instantiate(targs.tpes)
val tvars = constrained(tl)
tl.instantiate(tvars)

/** A fresh type variable added to the current constraint.
* @param bounds The initial bounds of the variable
Expand All @@ -779,7 +777,7 @@ object ProtoTypes {
pt => bounds :: Nil,
pt => represents.orElse(defn.AnyType))
constrained(poly, untpd.EmptyTree, alwaysAddTypeVars = true, nestingLevel)
._2.head.tpe.asInstanceOf[TypeVar]
._2.head

/** If `param` was created using `newTypeVar(..., represents = X)`, returns X.
* This is used in:
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4336,7 +4336,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
var typeArgs = tree match
case Select(qual, nme.CONSTRUCTOR) => qual.tpe.widenDealias.argTypesLo.map(TypeTree(_))
case _ => Nil
if typeArgs.isEmpty then typeArgs = constrained(poly, tree)._2
if typeArgs.isEmpty then typeArgs = constrained(poly, tree)._2.map(_.wrapInTypeTree(tree))
convertNewGenericArray(readapt(tree.appliedToTypeTrees(typeArgs)))
case wtp =>
val isStructuralCall = wtp.isValueType && isStructuralTermSelectOrApply(tree)
Expand Down
4 changes: 2 additions & 2 deletions compiler/test/dotty/tools/SignatureTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SignatureTest:
| def tuple2(x: Foo *: (T | Tuple) & Foo): Unit = {}
|""".stripMargin):
val cls = requiredClass("A")
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda], untpd.EmptyTree, alwaysAddTypeVars = true)._2.head.tpe
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda]).head
tvar <:< defn.TupleTypeRef
val prefix = cls.typeRef.appliedTo(tvar)

Expand All @@ -89,7 +89,7 @@ class SignatureTest:
| def and(x: T & Foo): Unit = {}
|""".stripMargin):
val cls = requiredClass("A")
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda], untpd.EmptyTree, alwaysAddTypeVars = true)._2.head.tpe
val tvar = constrained(cls.requiredMethod(nme.CONSTRUCTOR).info.asInstanceOf[TypeLambda]).head
val prefix = cls.typeRef.appliedTo(tvar)
val ref = prefix.select(cls.requiredMethod("and")).asInstanceOf[TermRef]

Expand Down
19 changes: 7 additions & 12 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class ConstraintsTest:
@Test def mergeParamsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T, R]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t, r) = tvars.tpes
val List(s, t, r) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
Expand All @@ -38,8 +37,7 @@ class ConstraintsTest:
@Test def mergeBoundsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes
val List(s, t) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
Expand All @@ -57,32 +55,29 @@ class ConstraintsTest:
@Test def validBoundsInit: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes
val List(s, t) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any
}

@Test def validBoundsUnify: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes
val List(s, t) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])

s <:< t

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}")
}

@Test def validBoundsReplace: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait X; trait A { def foo[S <: U | X, T, U]: Any }") {
val tvarTrees = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val tvars @ List(s, t, u) = tvarTrees.tpes.asInstanceOf[List[TypeVar]]
val tvars @ List(s, t, u) = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
s =:= t
t =:= u

Expand Down

0 comments on commit 5b57e09

Please sign in to comment.