Skip to content

Fix parameter untupling #14816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ object desugar {
*/
val MultiLineInfix: Property.Key[Unit] = Property.StickyKey()

/** An attachment key to indicate that a ValDef originated from parameter untupling.
*/
val UntupledParam: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
Expand Down Expand Up @@ -1426,7 +1430,9 @@ object desugar {
val vdefs =
params.zipWithIndex.map {
case (param, idx) =>
DefDef(param.name, Nil, param.tpt, selector(idx)).withSpan(param.span)
ValDef(param.name, param.tpt, selector(idx))
.withSpan(param.span)
.withAttachment(UntupledParam, ())
}
Function(param :: Nil, Block(vdefs, body))
}
Expand Down
21 changes: 12 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -404,18 +404,21 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}

/** A tree representing the same reference as the given type */
def ref(tp: NamedType)(using Context): Tree =
def ref(tp: NamedType, needLoad: Boolean = true)(using Context): Tree =
if (tp.isType) TypeTree(tp)
else if (prefixIsElidable(tp)) Ident(tp)
else if (tp.symbol.is(Module) && ctx.owner.isContainedIn(tp.symbol.moduleClass))
followOuterLinks(This(tp.symbol.moduleClass.asClass))
else if (tp.symbol hasAnnotation defn.ScalaStaticAnnot)
Ident(tp)
else {
else
val pre = tp.prefix
if (pre.isSingleton) followOuterLinks(singleton(pre.dealias)).select(tp)
else Select(TypeTree(pre), tp)
}
if (pre.isSingleton) followOuterLinks(singleton(pre.dealias, needLoad)).select(tp)
else
val res = Select(TypeTree(pre), tp)
if needLoad && !res.symbol.isStatic then
throw new TypeError(em"cannot establish a reference to $res")
res

def ref(sym: Symbol)(using Context): Tree =
ref(NamedType(sym.owner.thisType, sym.name, sym.denot))
Expand All @@ -428,11 +431,11 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
t
}

def singleton(tp: Type)(using Context): Tree = tp.dealias match {
case tp: TermRef => ref(tp)
def singleton(tp: Type, needLoad: Boolean = true)(using Context): Tree = tp.dealias match {
case tp: TermRef => ref(tp, needLoad)
case tp: ThisType => This(tp.cls)
case tp: SkolemType => singleton(tp.narrow)
case SuperType(qual, _) => singleton(qual)
case tp: SkolemType => singleton(tp.narrow, needLoad)
case SuperType(qual, _) => singleton(qual, needLoad)
case ConstantType(value) => Literal(value)
}

Expand Down
12 changes: 6 additions & 6 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ object Completion {
* @param content The source content that we'll check the positions for the prefix
* @param start The start position we'll start to look for the prefix at
* @param end The end position we'll look for the prefix at
* @return Either the full prefix including the ` or an empty string
* @return Either the full prefix including the ` or an empty string
*/
private def checkBacktickPrefix(content: Array[Char], start: Int, end: Int): String =
private def checkBacktickPrefix(content: Array[Char], start: Int, end: Int): String =
content.lift(start) match
case Some(char) if char == '`' =>
content.slice(start, end).mkString
Expand All @@ -111,7 +111,7 @@ object Completion {
// Foo.`se<TAB> will result in Select(Ident(Foo), <error>)
case (select: untpd.Select) :: _ if select.name == nme.ERROR =>
checkBacktickPrefix(select.source.content(), select.nameSpan.start, select.span.end)

// import scala.util.chaining.`s<TAB> will result in a Ident(<error>)
case (ident: untpd.Ident) :: _ if ident.name == nme.ERROR =>
checkBacktickPrefix(ident.source.content(), ident.span.start, ident.span.end)
Expand Down Expand Up @@ -177,14 +177,14 @@ object Completion {
// https://github.com/com-lihaoyi/Ammonite/blob/73a874173cd337f953a3edc9fb8cb96556638fdd/amm/util/src/main/scala/ammonite/util/Model.scala
private def needsBacktick(s: String) =
val chunks = s.split("_", -1)

val validChunks = chunks.zipWithIndex.forall { case (chunk, index) =>
chunk.forall(Chars.isIdentifierPart) ||
(chunk.forall(Chars.isOperatorPart) &&
index == chunks.length - 1 &&
!(chunks.lift(index - 1).contains("") && index - 1 == 0))
}

val validStart =
Chars.isIdentifierStart(s(0)) || chunks(0).forall(Chars.isOperatorPart)

Expand Down Expand Up @@ -312,7 +312,7 @@ object Completion {

/** Replaces underlying type with reduced one, when it's MatchType */
def reduceUnderlyingMatchType(qual: Tree)(using Context): Tree=
qual.tpe.widen match
qual.tpe.widen match
case ctx.typer.MatchTypeInDisguise(mt) => qual.withType(mt)
case _ => qual

Expand Down
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dotty.tools.dotc
package transform

import dotty.tools.dotc.ast.{Trees, tpd, untpd}
import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar}
import scala.collection.mutable
import core._
import dotty.tools.dotc.typer.Checking
Expand Down Expand Up @@ -255,6 +255,14 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
if tree.symbol.is(ConstructorProxy) then
report.error(em"constructor proxy ${tree.symbol} cannot be used as a value", tree.srcPos)

def checkStableSelection(tree: Tree)(using Context): Unit =
def check(qual: Tree) =
if !qual.tpe.isStable then
report.error(em"Parameter untupling cannot be used for call-by-name parameters", tree.srcPos)
tree match
case Select(qual, _) => check(qual) // simple select _n
case Apply(TypeApply(Select(qual, _), _), _) => check(qual) // generic select .apply[T](n)

override def transform(tree: Tree)(using Context): Tree =
try tree match {
// TODO move CaseDef case lower: keep most probable trees first for performance
Expand Down Expand Up @@ -356,6 +364,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
case tree: ValDef =>
checkErasedDef(tree)
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
checkStableSelection(tree.rhs)
processValOrDefDef(super.transform(tree1))
case tree: DefDef =>
checkErasedDef(tree)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ trait Applications extends Compatibility {
def fail(msg: Message): Unit =
ok = false
def appPos: SrcPos = NoSourcePosition
@threadUnsafe lazy val normalizedFun: Tree = ref(methRef)
@threadUnsafe lazy val normalizedFun: Tree = ref(methRef, needLoad = false)
init()
}

Expand Down Expand Up @@ -2268,7 +2268,7 @@ trait Applications extends Compatibility {
case TypeApply(fun, args) => TypeApply(replaceCallee(fun, replacement), args)
case _ => replacement

val methodRefTree = ref(methodRef)
val methodRefTree = ref(methodRef, needLoad = false)
val truncatedSym = methodRef.symbol.asTerm.copy(info = truncateExtension(methodRef.info))
val truncatedRefTree = untpd.TypedSplice(ref(truncatedSym)).withSpan(receiver.span)
val newCtx = ctx.fresh.setNewScope.setReporter(new reporting.ThrowingReporter(ctx.reporter))
Expand Down
7 changes: 7 additions & 0 deletions tests/neg/function-arity-2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object Test {

class T[A] { def foo(f: (=> A) => Int) = f(???) }

def main(args: Array[String]): Unit =
new T[(Int, Int)].foo((x, y) => 0) // error // error Parameter untupling cannot be used for call-by-name parameters (twice)
}
3 changes: 3 additions & 0 deletions tests/neg/i14783.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
object Test:
def foo(f: (=> (Int, Int)) => Int) = ???
foo((a, b) => a + b) // error // error
51 changes: 51 additions & 0 deletions tests/pos/i14783.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
object Wart:
def bar(using c: Ctx)(ws: List[Wrap[c.type]]): Unit =
ws.zipWithIndex.foreach { (w, _) => w.x.foo }

trait Wrap[C <: Ctx & Singleton]:
val ctx: C
def x: ctx.inner.X

trait Ctx:
object inner:
type X
extension (self: X) def foo: Int = ???


object WartInspector:
def myWartTraverser: WartTraverser = ???
def inspect(using q: Quotes)(tastys: List[Tasty[q.type]]): Unit = {
val universe: WartUniverse.Aux[q.type] = WartUniverse(q)
val traverser = myWartTraverser.get(universe)
tastys.zipWithIndex.foreach { (tasty, index) =>
val tree = tasty.ast
traverser.traverseTree(tree)(tree.symbol)
}
}

object WartUniverse:
type Aux[X <: Quotes] = WartUniverse { type Q = X }
def apply[Q <: Quotes](quotes: Q): Aux[Q] = ???


abstract class WartUniverse:
type Q <: Quotes
val quotes: Q
abstract class Traverser extends quotes.reflect.TreeTraverser


abstract class WartTraverser:
def get(u: WartUniverse): u.Traverser

trait Tasty[Q <: Quotes & Singleton]:
val quotes: Q
def path: String
def ast: quotes.reflect.Tree

trait Quotes:
object reflect:
type Tree
extension (self: Tree) def symbol: Symbol = ???
type Symbol
trait TreeTraverser:
def traverseTree(tree: Tree)(symbol: Symbol): Unit
2 changes: 1 addition & 1 deletion tests/run/function-arity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ object Test {

def main(args: Array[String]): Unit = {
new T[(Int, Int)].foo((ii) => 0)
new T[(Int, Int)].foo((x, y) => 0) // check that this does not run into ???
//new T[(Int, Int)].foo((x, y) => 0) // not allowed anymore
}
}