From 2ccba7c841d2cb43b165098258b4ccf969b93a16 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Mon, 6 Jul 2020 15:23:53 +0200 Subject: [PATCH 1/2] Disallow curried dependent context function types --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 4 ++-- compiler/src/dotty/tools/dotc/typer/Namer.scala | 15 ++++++++------- compiler/src/dotty/tools/dotc/typer/Typer.scala | 16 ++++++++++++---- tests/neg/curried-dependent-ift.scala | 17 +++++++++++++++++ tests/neg/i4668.scala | 2 +- 5 files changed, 40 insertions(+), 14 deletions(-) create mode 100644 tests/neg/curried-dependent-ift.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 36f71e6df03d..9557c8ae4268 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -559,7 +559,7 @@ object desugar { val copiedAccessFlags = if migrateTo3 then EmptyFlags else AccessFlags // Methods to add to a case class C[..](p1: T1, ..., pN: Tn)(moreParams) - // def _1: T1 = this.p1 + // def _1: T1 = this.p1 // ... // def _N: TN = this.pN (unless already given as valdef or parameterless defdef) // def copy(p1: T1 = p1: @uncheckedVariance, ..., @@ -572,7 +572,7 @@ object desugar { val caseClassMeths = { def syntheticProperty(name: TermName, tpt: Tree, rhs: Tree) = DefDef(name, Nil, Nil, tpt, rhs).withMods(synthetic) - + def productElemMeths = val caseParams = derivedVparamss.head.toArray val selectorNamesInBody = normalizedBody.collect { diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index c2da3d90b668..a6345ca32ab6 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -246,7 +246,10 @@ class Namer { typer: Typer => val xtree = expanded(tree) xtree.getAttachment(TypedAhead) match { case Some(ttree) => ttree.symbol - case none => xtree.attachment(SymOfTree) + case none => + xtree.getAttachment(SymOfTree) match + case Some(sym) => sym + case _ => throw IllegalArgumentException(i"$xtree does not have a symbol") } } @@ -443,14 +446,11 @@ class Namer { typer: Typer => /** If `sym` exists, enter it in effective scope. Check that * package members are not entered twice in the same run. */ - def enterSymbol(sym: Symbol)(using Context): Symbol = { + def enterSymbol(sym: Symbol)(using Context): Unit = // We do not enter Scala 2 macros defined in Scala 3 as they have an equivalent Scala 3 inline method. - if (sym.exists && !sym.isScala2MacroInScala3) { + if sym.exists && !sym.isScala2MacroInScala3 then typr.println(s"entered: $sym in ${ctx.owner}") ctx.enter(sym) - } - sym - } /** Create package if it does not yet exist. */ private def createPackageSymbol(pid: RefTree)(using Context): Symbol = { @@ -539,7 +539,8 @@ class Namer { typer: Typer => case imp: Import => ctx.importContext(imp, createSymbol(imp)) case mdef: DefTree => - val sym = enterSymbol(createSymbol(mdef)) + val sym = createSymbol(mdef) + enterSymbol(sym) setDocstring(sym, origStat) addEnumConstants(mdef, sym) ctx diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 12051ad2433e..ef5e706c4804 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -936,7 +936,7 @@ class Typer extends Namer * def double(x: Char): String = s"$x$x" * "abc" flatMap double */ - private def decomposeProtoFunction(pt: Type, defaultArity: Int)(using Context): (List[Type], untpd.Tree) = { + private def decomposeProtoFunction(pt: Type, defaultArity: Int, tree: untpd.Tree)(using Context): (List[Type], untpd.Tree) = { def typeTree(tp: Type) = tp match { case _: WildcardType => untpd.TypeTree() case _ => untpd.TypeTree(tp) @@ -947,7 +947,15 @@ class Typer extends Namer newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds) case _ => mapOver(t) } - pt.stripTypeVar.dealias match { + val pt1 = pt.stripTypeVar.dealias + if (pt1 ne pt1.dropDependentRefinement) + && defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType) + then + ctx.error( + i"""Implementation restriction: Expected result type $pt1 + |is a curried dependent context function type. Such types are not yet supported.""", + tree.sourcePos) + pt1 match { case pt1 if defn.isNonRefinedFunction(pt1) => // if expected parameter type(s) are wildcards, approximate from below. // if expected result type is a wildcard, approximate from above. @@ -960,7 +968,7 @@ class Typer extends Namer else typeTree(restpe)) case tp: TypeParamRef => - decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity) + decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree) case _ => (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) } @@ -1251,7 +1259,7 @@ class Typer extends Namer typedMatchFinish(tree, tpd.EmptyTree, defn.ImplicitScrutineeTypeRef, cases1, pt) } else { - val (protoFormals, _) = decomposeProtoFunction(pt, 1) + val (protoFormals, _) = decomposeProtoFunction(pt, 1, tree) val checkMode = if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None else desugar.MatchCheck.Exhaustive diff --git a/tests/neg/curried-dependent-ift.scala b/tests/neg/curried-dependent-ift.scala new file mode 100644 index 000000000000..8f8582a4a120 --- /dev/null +++ b/tests/neg/curried-dependent-ift.scala @@ -0,0 +1,17 @@ +trait Ctx1: + type T + val x: T + val y: T + +trait Ctx2: + type T + val x: T + val y: T + +trait A +trait B + +def h(x: Boolean): A ?=> B ?=> (A, B) = + (summon[A], summon[B]) // OK + +def g(x: Boolean): (c1: Ctx1) ?=> Ctx2 ?=> (c1.T, Ctx2) = ??? // error diff --git a/tests/neg/i4668.scala b/tests/neg/i4668.scala index 468de58b7d9f..6ea702f47f09 100644 --- a/tests/neg/i4668.scala +++ b/tests/neg/i4668.scala @@ -8,4 +8,4 @@ trait Functor[F[_]] { def map[A,B](x: F[A])(f: A => B): F[B] } object Functor { implicit object listFun extends Functor[List] { def map[A,B](ls: List[A])(f: A => B) = ls.map(f) } } val map: (A:Type,B:Type,F:Type1) ?=> (Functor[F.T]) ?=> (F.T[A.T]) => (A.T => B.T) => F.T[B.T] = - fun ?=> x => f => fun.map(x)(f) // error // error // error: Missing parameter type + fun ?=> x => f => fun.map(x)(f) // error \ No newline at end of file From 3d4260671de82920dde21eaaf86cd9eb880316fb Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Mon, 6 Jul 2020 17:38:08 +0200 Subject: [PATCH 2/2] Implement returns in methods with context function result types --- .../src/dotty/tools/dotc/typer/Typer.scala | 45 ++++++++++++++----- tests/neg/curried-dependent-ift.scala | 4 +- tests/run/ift-return.check | 2 + tests/run/ift-return.scala | 23 ++++++++++ 4 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 tests/run/ift-return.check create mode 100644 tests/run/ift-return.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index ef5e706c4804..2587e3d189f6 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1129,7 +1129,7 @@ class Typer extends Namer case _ => } - val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length) + val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree) /** The inferred parameter type for a parameter in a lambda that does * not have an explicit type given. @@ -1445,17 +1445,40 @@ class Typer extends Namer } def typedReturn(tree: untpd.Return)(using Context): Return = { + + /** If `pt` is a context function type, its return type. If the CFT + * is dependent, instantiate with the parameters of the associated + * anonymous function. + * @param paramss the parameters of the anonymous functions + * enclosing the return expression + */ + def instantiateCFT(pt: Type, paramss: => List[List[Symbol]]): Type = + val ift = defn.asContextFunctionType(pt) + if ift.exists then + ift.nonPrivateMember(nme.apply).info match + case appType: MethodType => + instantiateCFT(appType.instantiate(paramss.head.map(_.termRef)), paramss.tail) + else pt + def returnProto(owner: Symbol, locals: Scope): Type = if (owner.isConstructor) defn.UnitType - else owner.info match { - case info: PolyType => - val tparams = locals.toList.takeWhile(_ is TypeParam) - assert(info.paramNames.length == tparams.length, - i"return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, %") - info.instantiate(tparams.map(_.typeRef)).finalResultType - case info => - info.finalResultType - } + else + val rt = owner.info match + case info: PolyType => + val tparams = locals.toList.takeWhile(_ is TypeParam) + assert(info.paramNames.length == tparams.length, + i"return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, %") + info.instantiate(tparams.map(_.typeRef)).finalResultType + case info => + info.finalResultType + def iftParamss = ctx.owner.ownersIterator + .filter(_.is(Method, butNot = Accessor)) + .takeWhile(_.isAnonymousFunction) + .toList + .reverse + .map(_.paramSymss.head) + instantiateCFT(rt, iftParamss) + def enclMethInfo(cx: Context): (Tree, Type) = { val owner = cx.owner if (owner.isType) { @@ -3147,7 +3170,7 @@ class Typer extends Namer def isContextFunctionRef(wtp: Type): Boolean = wtp match { case RefinedType(parent, nme.apply, _) => - isContextFunctionRef(parent) // apply refinements indicate a dependent IFT + isContextFunctionRef(parent) // apply refinements indicate a dependent CFT case _ => val underlying = wtp.underlyingClassRef(refinementOK = false) // other refinements are not OK defn.isContextFunctionClass(underlying.classSymbol) diff --git a/tests/neg/curried-dependent-ift.scala b/tests/neg/curried-dependent-ift.scala index 8f8582a4a120..359514505613 100644 --- a/tests/neg/curried-dependent-ift.scala +++ b/tests/neg/curried-dependent-ift.scala @@ -14,4 +14,6 @@ trait B def h(x: Boolean): A ?=> B ?=> (A, B) = (summon[A], summon[B]) // OK -def g(x: Boolean): (c1: Ctx1) ?=> Ctx2 ?=> (c1.T, Ctx2) = ??? // error +def g(x: Boolean): (c1: Ctx1) ?=> Ctx2 ?=> (c1.T, Ctx2) = + return ??? // error + ??? diff --git a/tests/run/ift-return.check b/tests/run/ift-return.check new file mode 100644 index 000000000000..5494f0d4e3fb --- /dev/null +++ b/tests/run/ift-return.check @@ -0,0 +1,2 @@ +(22,abc) +(22,def) diff --git a/tests/run/ift-return.scala b/tests/run/ift-return.scala new file mode 100644 index 000000000000..b49f4c647ee0 --- /dev/null +++ b/tests/run/ift-return.scala @@ -0,0 +1,23 @@ +trait A: + val x: Int + +trait Ctx: + type T + val x: T + val y: T + +def f(x: Boolean): A ?=> (c: Ctx) ?=> (Int, c.T) = + if x then return (summon[A].x, summon[Ctx].x) + (summon[A].x, summon[Ctx].y) + +@main def Test = + given A: + val x = 22 + given Ctx: + type T = String + val x = "abc" + val y = "def" + + println(f(true)) + println(f(false)) +