Skip to content

Fix #2663: More refined handling of enum case apply results #3903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 38 additions & 26 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ object desugar {
def derivedType(sym: Symbol)(implicit ctx: Context) = sym.typeRef
}

class DerivedFromParamTree extends DerivedTypeTree {
/** A type tree that computes its type from an existing parameter.
* @param suffix String difference between existing parameter (call it `P`) and parameter owning the
* DerivedTypeTree (call it `O`). We have: `O.name == P.name + suffix`.
*/
class DerivedFromParamTree(suffix: String) extends DerivedTypeTree {

/** Make sure that for all enclosing module classes their companion lasses
* are completed. Reason: We need the constructor of such companion classes to
Expand All @@ -58,24 +62,28 @@ object desugar {

/** Return info of original symbol, where all references to siblings of the
* original symbol (i.e. sibling and original symbol have the same owner)
* are rewired to same-named parameters or accessors in the scope enclosing
* are rewired to like-named* parameters or accessors in the scope enclosing
* the current scope. The current scope is the scope owned by the defined symbol
* itself, that's why we have to look one scope further out. If the resulting
* type is an alias type, dealias it. This is necessary because the
* accessor of a type parameter is a private type alias that cannot be accessed
* from subclasses.
*
* (*) like-named means:
*
* parameter name == reference name ++ suffix
*/
def derivedType(sym: Symbol)(implicit ctx: Context) = {
val relocate = new TypeMap {
val originalOwner = sym.owner
def apply(tp: Type) = tp match {
case tp: NamedType if tp.symbol.exists && (tp.symbol.owner eq originalOwner) =>
val defctx = ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next()
var local = defctx.denotNamed(tp.name).suchThat(_.isParamOrAccessor).symbol
var local = defctx.denotNamed(tp.name ++ suffix).suchThat(_.isParamOrAccessor).symbol
if (local.exists) (defctx.owner.thisType select local).dealias
else {
def msg =
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope}"
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope.toList}"
if (ctx.reporter.errorsReported) ErrorType(msg)
else throw new java.lang.Error(msg)
}
Expand All @@ -88,14 +96,20 @@ object desugar {
}

/** A type definition copied from `tdef` with a rhs typetree derived from it */
def derivedTypeParam(tdef: TypeDef) =
def derivedTypeParam(tdef: TypeDef, suffix: String = ""): TypeDef =
cpy.TypeDef(tdef)(
rhs = new DerivedFromParamTree() withPos tdef.rhs.pos watching tdef)
name = tdef.name ++ suffix,
rhs = new DerivedFromParamTree(suffix).withPos(tdef.rhs.pos).watching(tdef)
)

/** A derived type definition watching `sym` */
def derivedTypeParam(sym: TypeSymbol)(implicit ctx: Context): TypeDef =
TypeDef(sym.name, new DerivedFromParamTree("").watching(sym)).withFlags(TypeParam)

/** A value definition copied from `vdef` with a tpt typetree derived from it */
def derivedTermParam(vdef: ValDef) =
cpy.ValDef(vdef)(
tpt = new DerivedFromParamTree() withPos vdef.tpt.pos watching vdef)
tpt = new DerivedFromParamTree("") withPos vdef.tpt.pos watching vdef)

// ----- Desugar methods -------------------------------------------------

Expand Down Expand Up @@ -317,8 +331,8 @@ object desugar {
}
def anyRef = ref(defn.AnyRefAlias.typeRef)

val derivedTparams = constrTparams map derivedTypeParam
val derivedVparamss = constrVparamss nestedMap derivedTermParam
val derivedTparams = constrTparams.map(derivedTypeParam(_))
val derivedVparamss = constrVparamss.nestedMap(derivedTermParam(_))
val arity = constrVparamss.head.length

val classTycon: Tree = new TypeRefTree // watching is set at end of method
Expand Down Expand Up @@ -419,9 +433,8 @@ object desugar {
// ev1: Eq[T1$1, T1$2], ..., evn: Eq[Tn$1, Tn$2]])
// : Eq[C[T1$1, ..., Tn$1], C[T1$2, ..., Tn$2]] = Eq
def eqInstance = {
def append(tdef: TypeDef, str: String) = cpy.TypeDef(tdef)(name = tdef.name ++ str)
val leftParams = derivedTparams.map(append(_, "$1"))
val rightParams = derivedTparams.map(append(_, "$2"))
val leftParams = constrTparams.map(derivedTypeParam(_, "$1"))
val rightParams = constrTparams.map(derivedTypeParam(_, "$2"))
val subInstances = (leftParams, rightParams).zipped.map((param1, param2) =>
appliedRef(ref(defn.EqType), List(param1, param2)))
DefDef(
Expand Down Expand Up @@ -456,19 +469,16 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
def extractType(t: Tree): Tree = t match {
case Apply(t1, _) => extractType(t1)
case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
case New(t1) => t1
case t1 => t1
}
// The return type of the `apply` method
val applyResultTpt =
if (isEnumCase)
if (parents.isEmpty) enumClassTypeRef
else parents.map(extractType).reduceLeft(AndTypeTree)
else TypeTree()
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else {
val tparams = enumClass.typeParams.map(derivedTypeParam)
enumApplyResult(cdef, parents, tparams, appliedRef(enumClassRef, tparams))
}

val parent =
if (constrTparams.nonEmpty ||
Expand All @@ -479,11 +489,13 @@ object desugar {
// todo: also use anyRef if constructor has a dependent method type (or rule that out)!
(constrVparamss :\ (if (isEnumCase) applyResultTpt else classTypeRef)) (
(vparams, restpe) => Function(vparams map (_.tpt), restpe))
def widenedCreatorExpr =
(creatorExpr /: widenDefs)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods is Abstract) Nil
else
DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, creatorExpr)
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil
DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr)
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: widenDefs
val unapplyMeth = {
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
Expand Down
46 changes: 46 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,52 @@ object DesugarEnums {
TypeTree(), creator)
}

/** The return type of an enum case apply method and any widening methods in which
* the apply's right hand side will be wrapped. For parents of the form
*
* extends E(args) with T1(args1) with ... TN(argsN)
*
* and type parameters `tparams` the generated widen method is
*
* def C$to$E[tparams](x$1: E[tparams] with T1 with ... TN) = x$1
*
* @param cdef The case definition
* @param parents The declared parents of the enum case
* @param tparams The type parameters of the enum case
* @param appliedEnumRef The enum class applied to `tparams`.
*/
def enumApplyResult(
cdef: TypeDef,
parents: List[Tree],
tparams: List[TypeDef],
appliedEnumRef: Tree)(implicit ctx: Context): (Tree, List[DefDef]) = {

def extractType(t: Tree): Tree = t match {
case Apply(t1, _) => extractType(t1)
case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
case New(t1) => t1
case t1 => t1
}

val parentTypes = parents.map(extractType)
parentTypes.head match {
case parent: RefTree if parent.name == enumClass.name =>
// need a widen method to compute correct type parameters for enum base class
val widenParamType = (appliedEnumRef /: parentTypes.tail)(AndTypeTree)
val widenParam = makeSyntheticParameter(tpt = widenParamType)
val widenDef = DefDef(
name = s"${cdef.name}$$to$$${enumClass.name}".toTermName,
tparams = tparams,
vparamss = (widenParam :: Nil) :: Nil,
tpt = TypeTree(),
rhs = Ident(widenParam.name))
(TypeTree(), widenDef :: Nil)
case _ =>
(parentTypes.reduceLeft(AndTypeTree), Nil)
}
}

/** A pair consisting of
* - the next enum tag
* - scaffolding containing the necessary definitions for singleton enum cases
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
this
}

/** Install the derived type tree as a dependency on `sym` */
def watching(sym: Symbol): this.type = {
pushAttachment(OriginalSymbol, sym)
this
}

/** A hook to ensure that all necessary symbols are completed so that
* OriginalSymbol attachments are propagated to this tree
*/
Expand All @@ -240,7 +246,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
* from the symbol in this type. These type trees have marker trees
* TypeRefOfSym or InfoOfSym as their originals.
*/
val References = new Property.Key[List[Tree]]
val References = new Property.Key[List[DerivedTypeTree]]

/** Property key for TypeTrees marked with TypeRefOfSym or InfoOfSym
* which contains the symbol of the original tree from which this
Expand Down
6 changes: 1 addition & 5 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,7 @@ class Namer { typer: Typer =>

/** Record `sym` as the symbol defined by `tree` */
def recordSym(sym: Symbol, tree: Tree)(implicit ctx: Context): Symbol = {
val refs = tree.attachmentOrElse(References, Nil)
if (refs.nonEmpty) {
tree.removeAttachment(References)
refs foreach (_.pushAttachment(OriginalSymbol, sym))
}
for (refs <- tree.removeAttachment(References); ref <- refs) ref.watching(sym)
tree.pushAttachment(SymOfTree, sym)
sym
}
Expand Down
30 changes: 30 additions & 0 deletions tests/pos/i2663.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
trait Tr
enum Foo[T](x: T) {
case Bar[T](y: T) extends Foo(y)
case Bas[T](y: Int) extends Foo(y)
case Bam[T](y: String) extends Foo(y) with Tr
case Baz[S, T](y: String) extends Foo(y) with Tr
}
object Test {
import Foo._
val bar: Foo[Boolean] = Bar(true)
val bas: Foo[Int] = Bas(1)
val bam: Foo[String] & Tr = Bam("")
val baz: Foo[String] & Tr = Baz("")
}

enum Foo2[S <: T, T](x1: S, x2: T) {
case Bar[T](y: T) extends Foo2(y, y)
case Bas[T](y: Int) extends Foo2(y, y)
case Bam[T](y: String) extends Foo2(y, y) with Tr
case Baz[S, T](y: String) extends Foo2(y, y) with Tr
}
object Test2 {
import Foo2._
val bar: Foo2[Boolean, Boolean] = Bar(true)
val bas: Foo2[Int, Int] = Bas(1)
val bam: Foo2[String, String] & Tr = Bam("")
val baz: Foo2[String, String] & Tr = Baz("")
}