Skip to content

Commit 29082e1

Browse files
committed
Add infrastructure for setting paramVariances in HKTypeLambdas
This is meant as a better alternative to encode variances in parameter names.
1 parent cf75785 commit 29082e1

File tree

3 files changed

+94
-47
lines changed

3 files changed

+94
-47
lines changed

compiler/src/dotty/tools/dotc/core/Hashable.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import scala.util.hashing.{ MurmurHash3 => hashing }
66
import annotation.tailrec
77

88
object Hashable {
9-
9+
1010
/** A null terminated list of BindingTypes. We use `null` here for efficiency */
1111
class Binders(val tp: BindingType, val next: Binders)
1212

@@ -44,7 +44,7 @@ trait Hashable {
4444
avoidSpecialHashes(hashing.finalizeHash(hashCode, arity))
4545

4646
final def typeHash(bs: Binders, tp: Type): Int =
47-
if (bs == null || tp.stableHash) tp.hash else tp.computeHash(bs)
47+
if (bs == null || tp.hashIsStable) tp.hash else tp.computeHash(bs)
4848

4949
def identityHash(bs: Binders): Int = avoidSpecialHashes(System.identityHashCode(this))
5050

@@ -80,7 +80,7 @@ trait Hashable {
8080
finishHash(bs, hashing.mix(seed, elemHash), arity + 1, tps)
8181
}
8282

83-
83+
8484
protected final def doHash(x: Any): Int =
8585
finishHash(hashing.mix(hashSeed, x.hashCode), 1)
8686

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 85 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import Decorators._
1818
import Denotations._
1919
import Periods._
2020
import CheckRealizable._
21+
import Variances.{Variance, varianceFromInt, varianceToInt}
2122
import util.Stats._
2223
import util.SimpleIdentitySet
2324
import reporting.diagnostic.Message
@@ -29,7 +30,8 @@ import Hashable._
2930
import Uniques._
3031
import collection.mutable
3132
import config.Config
32-
import annotation.tailrec
33+
import annotation.{tailrec, constructorOnly}
34+
3335
import language.implicitConversions
3436
import scala.util.hashing.{ MurmurHash3 => hashing }
3537
import config.Printers.{core, typr}
@@ -1670,7 +1672,7 @@ object Types {
16701672
def computeHash(bs: Binders): Int
16711673

16721674
/** Is the `hash` of this type the same for all possible sequences of enclosing binders? */
1673-
def stableHash: Boolean = true
1675+
def hashIsStable: Boolean = true
16741676
}
16751677

16761678
// end Type
@@ -2331,8 +2333,8 @@ object Types {
23312333

23322334
override def computeHash(bs: Binders): Int = doHash(bs, designator, prefix)
23332335

2334-
override def stableHash: Boolean = {
2335-
if (myStableHash == 0) myStableHash = if (prefix.stableHash) 1 else -1
2336+
override def hashIsStable: Boolean = {
2337+
if (myStableHash == 0) myStableHash = if (prefix.hashIsStable) 1 else -1
23362338
myStableHash > 0
23372339
}
23382340

@@ -2616,7 +2618,7 @@ object Types {
26162618
else parent
26172619

26182620
override def computeHash(bs: Binders): Int = doHash(bs, refinedName, refinedInfo, parent)
2619-
override def stableHash: Boolean = refinedInfo.stableHash && parent.stableHash
2621+
override def hashIsStable: Boolean = refinedInfo.hashIsStable && parent.hashIsStable
26202622

26212623
override def eql(that: Type): Boolean = that match {
26222624
case that: RefinedType =>
@@ -2717,7 +2719,7 @@ object Types {
27172719

27182720
override def computeHash(bs: Binders): Int = doHash(new Binders(this, bs), parent)
27192721

2720-
override def stableHash: Boolean = false
2722+
override def hashIsStable: Boolean = false
27212723
// this is a conservative observation. By construction RecTypes contain at least
27222724
// one RecThis occurrence. Since `stableHash` does not keep track of enclosing
27232725
// bound types, it will return "unstable" for this occurrence and this would propagate.
@@ -3025,7 +3027,7 @@ object Types {
30253027
if (resType eq this.resType) this else ExprType(resType)
30263028

30273029
override def computeHash(bs: Binders): Int = doHash(bs, resType)
3028-
override def stableHash: Boolean = resType.stableHash
3030+
override def hashIsStable: Boolean = resType.hashIsStable
30293031

30303032
override def eql(that: Type): Boolean = that match {
30313033
case that: ExprType => resType.eq(that.resType)
@@ -3110,38 +3112,19 @@ object Types {
31103112
if ((paramNames eq this.paramNames) && (paramInfos eq this.paramInfos) && (resType eq this.resType)) this
31113113
else newLikeThis(paramNames, paramInfos, resType)
31123114

3113-
final def newLikeThis(paramNames: List[ThisName], paramInfos: List[PInfo], resType: Type)(implicit ctx: Context): This =
3115+
def newLikeThis(paramNames: List[ThisName], paramInfos: List[PInfo], resType: Type)(implicit ctx: Context): This =
31143116
companion(paramNames)(
31153117
x => paramInfos.mapConserve(_.subst(this, x).asInstanceOf[PInfo]),
31163118
x => resType.subst(this, x))
31173119

31183120
protected def prefixString: String
3119-
final override def toString: String = s"$prefixString($paramNames, $paramInfos, $resType)"
3121+
override def toString: String = s"$prefixString($paramNames, $paramInfos, $resType)"
31203122
}
31213123

31223124
abstract class HKLambda extends CachedProxyType with LambdaType {
31233125
final override def underlying(implicit ctx: Context): Type = resType
3124-
3125-
override def computeHash(bs: Binders): Int =
3126-
doHash(new Binders(this, bs), paramNames, resType, paramInfos)
3127-
3128-
override def stableHash: Boolean = resType.stableHash && paramInfos.stableHash
3129-
3126+
final override def hashIsStable: Boolean = resType.hashIsStable && paramInfos.hashIsStable
31303127
final override def equals(that: Any): Boolean = equals(that, null)
3131-
3132-
// No definition of `eql` --> fall back on equals, which calls iso
3133-
3134-
final override def iso(that: Any, bs: BinderPairs): Boolean = that match {
3135-
case that: HKLambda =>
3136-
paramNames.eqElements(that.paramNames) &&
3137-
companion.eq(that.companion) && {
3138-
val bs1 = new BinderPairs(this, that, bs)
3139-
paramInfos.equalElements(that.paramInfos, bs1) &&
3140-
resType.equals(that.resType, bs1)
3141-
}
3142-
case _ =>
3143-
false
3144-
}
31453128
}
31463129

31473130
abstract class MethodOrPoly extends UncachedGroundType with LambdaType with MethodicType {
@@ -3430,8 +3413,12 @@ object Types {
34303413

34313414
def newParamRef(n: Int): TypeParamRef = new TypeParamRefImpl(this, n)
34323415

3433-
@threadUnsafe lazy val typeParams: List[LambdaParam] =
3434-
paramNames.indices.toList.map(new LambdaParam(this, _))
3416+
protected var myTypeParams: List[LambdaParam] = null
3417+
3418+
def typeParams: List[LambdaParam] =
3419+
if myTypeParams == null then
3420+
myTypeParams = paramNames.indices.toList.map(new LambdaParam(this, _))
3421+
myTypeParams
34353422

34363423
def derivedLambdaAbstraction(paramNames: List[TypeName], paramInfos: List[TypeBounds], resType: Type)(implicit ctx: Context): Type =
34373424
resType match {
@@ -3457,7 +3444,7 @@ object Types {
34573444
* @param resultTypeExp A function that, given the polytype itself, returns the
34583445
* result type `T`.
34593446
*/
3460-
class HKTypeLambda(val paramNames: List[TypeName])(
3447+
class HKTypeLambda(val paramNames: List[TypeName], @constructorOnly variances: List[Variance])(
34613448
paramInfosExp: HKTypeLambda => List[TypeBounds], resultTypeExp: HKTypeLambda => Type)
34623449
extends HKLambda with TypeLambda {
34633450
type This = HKTypeLambda
@@ -3469,7 +3456,49 @@ object Types {
34693456
assert(resType.isInstanceOf[TermType], this)
34703457
assert(paramNames.nonEmpty)
34713458

3459+
private def setVariances(tparams: List[LambdaParam], vs: List[Variance]): Unit =
3460+
if tparams.nonEmpty then
3461+
tparams.head.setVariance(vs.head)
3462+
setVariances(tparams.tail, vs.tail)
3463+
3464+
val isVariant = variances.nonEmpty
3465+
if isVariant then setVariances(typeParams, variances)
3466+
3467+
def givenVariances =
3468+
if isVariant then typeParams.map(_.paramVariance)
3469+
else Nil
3470+
3471+
override def computeHash(bs: Binders): Int =
3472+
doHash(new Binders(this, bs), givenVariances ::: paramNames, resType, paramInfos)
3473+
3474+
// No definition of `eql` --> fall back on equals, which calls iso
3475+
3476+
final override def iso(that: Any, bs: BinderPairs): Boolean = that match {
3477+
case that: HKTypeLambda =>
3478+
paramNames.eqElements(that.paramNames)
3479+
&& isVariant == that.isVariant
3480+
&& (!isVariant
3481+
|| typeParams.corresponds(that.typeParams)((x, y) =>
3482+
x.paramVariance == y.paramVariance))
3483+
&& {
3484+
val bs1 = new BinderPairs(this, that, bs)
3485+
paramInfos.equalElements(that.paramInfos, bs1) &&
3486+
resType.equals(that.resType, bs1)
3487+
}
3488+
case _ =>
3489+
false
3490+
}
3491+
3492+
override def newLikeThis(paramNames: List[ThisName], paramInfos: List[PInfo], resType: Type)(implicit ctx: Context): This =
3493+
HKTypeLambda(paramNames, givenVariances)(
3494+
x => paramInfos.mapConserve(_.subst(this, x).asInstanceOf[PInfo]),
3495+
x => resType.subst(this, x))
3496+
34723497
protected def prefixString: String = "HKTypeLambda"
3498+
final override def toString: String =
3499+
if isVariant then
3500+
s"HKTypeLambda($paramNames, $paramInfos, $resType, ${givenVariances.map(_.flagsString)})"
3501+
else super.toString
34733502
}
34743503

34753504
/** The type of a polymorphic method. It has the same form as HKTypeLambda,
@@ -3519,7 +3548,12 @@ object Types {
35193548
def apply(paramNames: List[TypeName])(
35203549
paramInfosExp: HKTypeLambda => List[TypeBounds],
35213550
resultTypeExp: HKTypeLambda => Type)(implicit ctx: Context): HKTypeLambda =
3522-
unique(new HKTypeLambda(paramNames)(paramInfosExp, resultTypeExp))
3551+
apply(paramNames, Nil)(paramInfosExp, resultTypeExp)
3552+
3553+
def apply(paramNames: List[TypeName], variances: List[Variance])(
3554+
paramInfosExp: HKTypeLambda => List[TypeBounds],
3555+
resultTypeExp: HKTypeLambda => Type)(implicit ctx: Context): HKTypeLambda =
3556+
unique(new HKTypeLambda(paramNames, variances)(paramInfosExp, resultTypeExp))
35233557

35243558
def unapply(tl: HKTypeLambda): Some[(List[LambdaParam], Type)] =
35253559
Some((tl.typeParams, tl.resType))
@@ -3580,13 +3614,21 @@ object Types {
35803614
/** The parameter of a type lambda */
35813615
case class LambdaParam(tl: TypeLambda, n: Int) extends ParamInfo {
35823616
type ThisName = TypeName
3617+
35833618
def isTypeParam(implicit ctx: Context): Boolean = tl.paramNames.head.isTypeName
35843619
def paramName(implicit ctx: Context): tl.ThisName = tl.paramNames(n)
35853620
def paramInfo(implicit ctx: Context): tl.PInfo = tl.paramInfos(n)
35863621
def paramInfoAsSeenFrom(pre: Type)(implicit ctx: Context): tl.PInfo = paramInfo
35873622
def paramInfoOrCompleter(implicit ctx: Context): Type = paramInfo
35883623
def paramVarianceSign(implicit ctx: Context): Int = tl.paramNames(n).variance
35893624
def paramRef(implicit ctx: Context): Type = tl.paramRefs(n)
3625+
3626+
private var myVariance: FlagSet = UndefinedFlags
3627+
def setVariance(v: Variance): Unit = myVariance = v
3628+
def paramVariance: Variance =
3629+
if myVariance == UndefinedFlags then
3630+
myVariance = varianceFromInt(tl.paramNames(n).variance)
3631+
myVariance
35903632
}
35913633

35923634
/** A type application `C[T_1, ..., T_n]` */
@@ -3758,8 +3800,8 @@ object Types {
37583800

37593801
override def computeHash(bs: Binders): Int = doHash(bs, tycon, args)
37603802

3761-
override def stableHash: Boolean = {
3762-
if (myStableHash == 0) myStableHash = if (tycon.stableHash && args.stableHash) 1 else -1
3803+
override def hashIsStable: Boolean = {
3804+
if (myStableHash == 0) myStableHash = if (tycon.hashIsStable && args.hashIsStable) 1 else -1
37633805
myStableHash > 0
37643806
}
37653807

@@ -3790,7 +3832,7 @@ object Types {
37903832
type BT <: Type
37913833
val binder: BT
37923834
def copyBoundType(bt: BT): Type
3793-
override def stableHash: Boolean = false
3835+
override def hashIsStable: Boolean = false
37943836
}
37953837

37963838
abstract class ParamRef extends BoundType {
@@ -4193,7 +4235,7 @@ object Types {
41934235
else ClassInfo(prefix, cls, classParents, decls, selfInfo)
41944236

41954237
override def computeHash(bs: Binders): Int = doHash(bs, cls, prefix)
4196-
override def stableHash: Boolean = prefix.stableHash && classParents.stableHash
4238+
override def hashIsStable: Boolean = prefix.hashIsStable && classParents.hashIsStable
41974239

41984240
override def eql(that: Type): Boolean = that match {
41994241
case that: ClassInfo =>
@@ -4290,7 +4332,7 @@ object Types {
42904332
}
42914333

42924334
override def computeHash(bs: Binders): Int = doHash(bs, lo, hi)
4293-
override def stableHash: Boolean = lo.stableHash && hi.stableHash
4335+
override def hashIsStable: Boolean = lo.hashIsStable && hi.hashIsStable
42944336

42954337
override def equals(that: Any): Boolean = equals(that, null)
42964338

@@ -4315,7 +4357,7 @@ object Types {
43154357
def derivedAlias(alias: Type)(implicit ctx: Context): AliasingBounds
43164358

43174359
override def computeHash(bs: Binders): Int = doHash(bs, alias)
4318-
override def stableHash: Boolean = alias.stableHash
4360+
override def hashIsStable: Boolean = alias.hashIsStable
43194361

43204362
override def iso(that: Any, bs: BinderPairs): Boolean = that match {
43214363
case that: AliasingBounds => this.isTypeAlias == that.isTypeAlias && alias.equals(that.alias, bs)
@@ -4416,7 +4458,7 @@ object Types {
44164458
if (elemtp eq this.elemType) this else JavaArrayType(elemtp)
44174459

44184460
override def computeHash(bs: Binders): Int = doHash(bs, elemType)
4419-
override def stableHash: Boolean = elemType.stableHash
4461+
override def hashIsStable: Boolean = elemType.hashIsStable
44204462

44214463
override def eql(that: Type): Boolean = that match {
44224464
case that: JavaArrayType => elemType.eq(that.elemType)
@@ -4479,7 +4521,7 @@ object Types {
44794521
else WildcardType(optBounds.asInstanceOf[TypeBounds])
44804522

44814523
override def computeHash(bs: Binders): Int = doHash(bs, optBounds)
4482-
override def stableHash: Boolean = optBounds.stableHash
4524+
override def hashIsStable: Boolean = optBounds.hashIsStable
44834525

44844526
override def eql(that: Type): Boolean = that match {
44854527
case that: WildcardType => optBounds.eq(that.optBounds)
@@ -5386,8 +5428,8 @@ object Types {
53865428
implicit def decorateTypeApplications(tpe: Type): TypeApplications = new TypeApplications(tpe)
53875429

53885430
implicit class typeListDeco(val tps1: List[Type]) extends AnyVal {
5389-
@tailrec def stableHash: Boolean =
5390-
tps1.isEmpty || tps1.head.stableHash && tps1.tail.stableHash
5431+
@tailrec def hashIsStable: Boolean =
5432+
tps1.isEmpty || tps1.head.hashIsStable && tps1.tail.hashIsStable
53915433
@tailrec def equalElements(tps2: List[Type], bs: BinderPairs): Boolean =
53925434
(tps1 `eq` tps2) || {
53935435
if (tps1.isEmpty) tps2.isEmpty

compiler/src/dotty/tools/dotc/core/Variances.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,16 @@ object Variances {
1010
val Bivariant: Variance = VarianceFlags
1111
val Invariant: Variance = EmptyFlags
1212

13-
def varianceFromInt(v: Int) =
13+
def varianceFromInt(v: Int): Variance =
1414
if v < 0 then Covariant
1515
else if v > 0 then Contravariant
1616
else Invariant
1717

18+
def varianceToInt(v: Variance): Int =
19+
if v.is(Covariant) then 1
20+
else if v.is(Contravariant) then -1
21+
else 0
22+
1823
/** Flip between covariant and contravariant */
1924
def flip(v: Variance): Variance =
2025
if (v == Covariant) Contravariant

0 commit comments

Comments
 (0)