Skip to content

Commit 3649fa8

Browse files
authored
Merge pull request #3903 from dotty-staging/fix-#2663
Fix #2663: More refined handling of enum case apply results
2 parents cdc25fd + 393c620 commit 3649fa8

File tree

5 files changed

+122
-32
lines changed

5 files changed

+122
-32
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+38-26
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ object desugar {
4040
def derivedType(sym: Symbol)(implicit ctx: Context) = sym.typeRef
4141
}
4242

43-
class DerivedFromParamTree extends DerivedTypeTree {
43+
/** A type tree that computes its type from an existing parameter.
44+
* @param suffix String difference between existing parameter (call it `P`) and parameter owning the
45+
* DerivedTypeTree (call it `O`). We have: `O.name == P.name + suffix`.
46+
*/
47+
class DerivedFromParamTree(suffix: String) extends DerivedTypeTree {
4448

4549
/** Make sure that for all enclosing module classes their companion lasses
4650
* are completed. Reason: We need the constructor of such companion classes to
@@ -58,24 +62,28 @@ object desugar {
5862

5963
/** Return info of original symbol, where all references to siblings of the
6064
* original symbol (i.e. sibling and original symbol have the same owner)
61-
* are rewired to same-named parameters or accessors in the scope enclosing
65+
* are rewired to like-named* parameters or accessors in the scope enclosing
6266
* the current scope. The current scope is the scope owned by the defined symbol
6367
* itself, that's why we have to look one scope further out. If the resulting
6468
* type is an alias type, dealias it. This is necessary because the
6569
* accessor of a type parameter is a private type alias that cannot be accessed
6670
* from subclasses.
71+
*
72+
* (*) like-named means:
73+
*
74+
* parameter name == reference name ++ suffix
6775
*/
6876
def derivedType(sym: Symbol)(implicit ctx: Context) = {
6977
val relocate = new TypeMap {
7078
val originalOwner = sym.owner
7179
def apply(tp: Type) = tp match {
7280
case tp: NamedType if tp.symbol.exists && (tp.symbol.owner eq originalOwner) =>
7381
val defctx = ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next()
74-
var local = defctx.denotNamed(tp.name).suchThat(_.isParamOrAccessor).symbol
82+
var local = defctx.denotNamed(tp.name ++ suffix).suchThat(_.isParamOrAccessor).symbol
7583
if (local.exists) (defctx.owner.thisType select local).dealias
7684
else {
7785
def msg =
78-
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope}"
86+
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope.toList}"
7987
if (ctx.reporter.errorsReported) ErrorType(msg)
8088
else throw new java.lang.Error(msg)
8189
}
@@ -88,14 +96,20 @@ object desugar {
8896
}
8997

9098
/** A type definition copied from `tdef` with a rhs typetree derived from it */
91-
def derivedTypeParam(tdef: TypeDef) =
99+
def derivedTypeParam(tdef: TypeDef, suffix: String = ""): TypeDef =
92100
cpy.TypeDef(tdef)(
93-
rhs = new DerivedFromParamTree() withPos tdef.rhs.pos watching tdef)
101+
name = tdef.name ++ suffix,
102+
rhs = new DerivedFromParamTree(suffix).withPos(tdef.rhs.pos).watching(tdef)
103+
)
104+
105+
/** A derived type definition watching `sym` */
106+
def derivedTypeParam(sym: TypeSymbol)(implicit ctx: Context): TypeDef =
107+
TypeDef(sym.name, new DerivedFromParamTree("").watching(sym)).withFlags(TypeParam)
94108

95109
/** A value definition copied from `vdef` with a tpt typetree derived from it */
96110
def derivedTermParam(vdef: ValDef) =
97111
cpy.ValDef(vdef)(
98-
tpt = new DerivedFromParamTree() withPos vdef.tpt.pos watching vdef)
112+
tpt = new DerivedFromParamTree("") withPos vdef.tpt.pos watching vdef)
99113

100114
// ----- Desugar methods -------------------------------------------------
101115

@@ -317,8 +331,8 @@ object desugar {
317331
}
318332
def anyRef = ref(defn.AnyRefAlias.typeRef)
319333

320-
val derivedTparams = constrTparams map derivedTypeParam
321-
val derivedVparamss = constrVparamss nestedMap derivedTermParam
334+
val derivedTparams = constrTparams.map(derivedTypeParam(_))
335+
val derivedVparamss = constrVparamss.nestedMap(derivedTermParam(_))
322336
val arity = constrVparamss.head.length
323337

324338
val classTycon: Tree = new TypeRefTree // watching is set at end of method
@@ -419,9 +433,8 @@ object desugar {
419433
// ev1: Eq[T1$1, T1$2], ..., evn: Eq[Tn$1, Tn$2]])
420434
// : Eq[C[T1$1, ..., Tn$1], C[T1$2, ..., Tn$2]] = Eq
421435
def eqInstance = {
422-
def append(tdef: TypeDef, str: String) = cpy.TypeDef(tdef)(name = tdef.name ++ str)
423-
val leftParams = derivedTparams.map(append(_, "$1"))
424-
val rightParams = derivedTparams.map(append(_, "$2"))
436+
val leftParams = constrTparams.map(derivedTypeParam(_, "$1"))
437+
val rightParams = constrTparams.map(derivedTypeParam(_, "$2"))
425438
val subInstances = (leftParams, rightParams).zipped.map((param1, param2) =>
426439
appliedRef(ref(defn.EqType), List(param1, param2)))
427440
DefDef(
@@ -456,19 +469,16 @@ object desugar {
456469
// For all other classes, the parent is AnyRef.
457470
val companions =
458471
if (isCaseClass) {
459-
def extractType(t: Tree): Tree = t match {
460-
case Apply(t1, _) => extractType(t1)
461-
case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
462-
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
463-
case New(t1) => t1
464-
case t1 => t1
465-
}
466472
// The return type of the `apply` method
467-
val applyResultTpt =
468-
if (isEnumCase)
469-
if (parents.isEmpty) enumClassTypeRef
470-
else parents.map(extractType).reduceLeft(AndTypeTree)
471-
else TypeTree()
473+
val (applyResultTpt, widenDefs) =
474+
if (!isEnumCase)
475+
(TypeTree(), Nil)
476+
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
477+
(enumClassTypeRef, Nil)
478+
else {
479+
val tparams = enumClass.typeParams.map(derivedTypeParam)
480+
enumApplyResult(cdef, parents, tparams, appliedRef(enumClassRef, tparams))
481+
}
472482

473483
val parent =
474484
if (constrTparams.nonEmpty ||
@@ -479,11 +489,13 @@ object desugar {
479489
// todo: also use anyRef if constructor has a dependent method type (or rule that out)!
480490
(constrVparamss :\ (if (isEnumCase) applyResultTpt else classTypeRef)) (
481491
(vparams, restpe) => Function(vparams map (_.tpt), restpe))
492+
def widenedCreatorExpr =
493+
(creatorExpr /: widenDefs)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
482494
val applyMeths =
483495
if (mods is Abstract) Nil
484496
else
485-
DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, creatorExpr)
486-
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil
497+
DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr)
498+
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: widenDefs
487499
val unapplyMeth = {
488500
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
489501
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

+46
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,52 @@ object DesugarEnums {
120120
TypeTree(), creator)
121121
}
122122

123+
/** The return type of an enum case apply method and any widening methods in which
124+
* the apply's right hand side will be wrapped. For parents of the form
125+
*
126+
* extends E(args) with T1(args1) with ... TN(argsN)
127+
*
128+
* and type parameters `tparams` the generated widen method is
129+
*
130+
* def C$to$E[tparams](x$1: E[tparams] with T1 with ... TN) = x$1
131+
*
132+
* @param cdef The case definition
133+
* @param parents The declared parents of the enum case
134+
* @param tparams The type parameters of the enum case
135+
* @param appliedEnumRef The enum class applied to `tparams`.
136+
*/
137+
def enumApplyResult(
138+
cdef: TypeDef,
139+
parents: List[Tree],
140+
tparams: List[TypeDef],
141+
appliedEnumRef: Tree)(implicit ctx: Context): (Tree, List[DefDef]) = {
142+
143+
def extractType(t: Tree): Tree = t match {
144+
case Apply(t1, _) => extractType(t1)
145+
case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
146+
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
147+
case New(t1) => t1
148+
case t1 => t1
149+
}
150+
151+
val parentTypes = parents.map(extractType)
152+
parentTypes.head match {
153+
case parent: RefTree if parent.name == enumClass.name =>
154+
// need a widen method to compute correct type parameters for enum base class
155+
val widenParamType = (appliedEnumRef /: parentTypes.tail)(AndTypeTree)
156+
val widenParam = makeSyntheticParameter(tpt = widenParamType)
157+
val widenDef = DefDef(
158+
name = s"${cdef.name}$$to$$${enumClass.name}".toTermName,
159+
tparams = tparams,
160+
vparamss = (widenParam :: Nil) :: Nil,
161+
tpt = TypeTree(),
162+
rhs = Ident(widenParam.name))
163+
(TypeTree(), widenDef :: Nil)
164+
case _ =>
165+
(parentTypes.reduceLeft(AndTypeTree), Nil)
166+
}
167+
}
168+
123169
/** A pair consisting of
124170
* - the next enum tag
125171
* - scaffolding containing the necessary definitions for singleton enum cases

compiler/src/dotty/tools/dotc/ast/untpd.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
227227
this
228228
}
229229

230+
/** Install the derived type tree as a dependency on `sym` */
231+
def watching(sym: Symbol): this.type = {
232+
pushAttachment(OriginalSymbol, sym)
233+
this
234+
}
235+
230236
/** A hook to ensure that all necessary symbols are completed so that
231237
* OriginalSymbol attachments are propagated to this tree
232238
*/
@@ -240,7 +246,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
240246
* from the symbol in this type. These type trees have marker trees
241247
* TypeRefOfSym or InfoOfSym as their originals.
242248
*/
243-
val References = new Property.Key[List[Tree]]
249+
val References = new Property.Key[List[DerivedTypeTree]]
244250

245251
/** Property key for TypeTrees marked with TypeRefOfSym or InfoOfSym
246252
* which contains the symbol of the original tree from which this

compiler/src/dotty/tools/dotc/typer/Namer.scala

+1-5
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,7 @@ class Namer { typer: Typer =>
228228

229229
/** Record `sym` as the symbol defined by `tree` */
230230
def recordSym(sym: Symbol, tree: Tree)(implicit ctx: Context): Symbol = {
231-
val refs = tree.attachmentOrElse(References, Nil)
232-
if (refs.nonEmpty) {
233-
tree.removeAttachment(References)
234-
refs foreach (_.pushAttachment(OriginalSymbol, sym))
235-
}
231+
for (refs <- tree.removeAttachment(References); ref <- refs) ref.watching(sym)
236232
tree.pushAttachment(SymOfTree, sym)
237233
sym
238234
}

tests/pos/i2663.scala

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
trait Tr
2+
enum Foo[T](x: T) {
3+
case Bar[T](y: T) extends Foo(y)
4+
case Bas[T](y: Int) extends Foo(y)
5+
case Bam[T](y: String) extends Foo(y) with Tr
6+
case Baz[S, T](y: String) extends Foo(y) with Tr
7+
}
8+
object Test {
9+
import Foo._
10+
val bar: Foo[Boolean] = Bar(true)
11+
val bas: Foo[Int] = Bas(1)
12+
val bam: Foo[String] & Tr = Bam("")
13+
val baz: Foo[String] & Tr = Baz("")
14+
}
15+
16+
enum Foo2[S <: T, T](x1: S, x2: T) {
17+
case Bar[T](y: T) extends Foo2(y, y)
18+
case Bas[T](y: Int) extends Foo2(y, y)
19+
case Bam[T](y: String) extends Foo2(y, y) with Tr
20+
case Baz[S, T](y: String) extends Foo2(y, y) with Tr
21+
}
22+
object Test2 {
23+
import Foo2._
24+
val bar: Foo2[Boolean, Boolean] = Bar(true)
25+
val bas: Foo2[Int, Int] = Bas(1)
26+
val bam: Foo2[String, String] & Tr = Bam("")
27+
val baz: Foo2[String, String] & Tr = Baz("")
28+
}
29+
30+

0 commit comments

Comments
 (0)