Skip to content

Fix #23224: Optimize simple tuple extraction #23373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 73 additions & 20 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,27 @@ object desugar {
sel
end match

case class TuplePatternInfo(arity: Int, varNum: Int, wildcardNum: Int)
object TuplePatternInfo:
def apply(pat: Tree)(using Context): TuplePatternInfo = pat match
case Tuple(pats) =>
var arity = 0
var varNum = 0
var wildcardNum = 0
pats.foreach: p =>
arity += 1
p match
case id: Ident if !isBackquoted(id) =>
if id.name.isVarPattern then
varNum += 1
if id.name == nme.WILDCARD then
wildcardNum += 1
case _ =>
TuplePatternInfo(arity, varNum, wildcardNum)
case _ =>
TuplePatternInfo(-1, -1, -1)
end TuplePatternInfo

/** If `pat` is a variable pattern,
*
* val/var/lazy val p = e
Expand Down Expand Up @@ -1483,30 +1504,47 @@ object desugar {
|please bind to an identifier and use an alias given.""", bind)
false

def isTuplePattern(arity: Int): Boolean = pat match {
case Tuple(pats) if pats.size == arity =>
pats.forall(isVarPattern)
case _ => false
}

val isMatchingTuple: Tree => Boolean = {
case Tuple(es) => isTuplePattern(es.length) && !hasNamedArg(es)
case _ => false
}
val tuplePatternInfo = TuplePatternInfo(pat)

// When desugaring a PatDef in general, we use pattern matching on the rhs
// and collect the variable values in a tuple, then outside the match,
// we destructure the tuple to get the individual variables.
// We can achieve two kinds of tuple optimizations if the pattern is a tuple
// of simple variables or wildcards:
// 1. Full optimization:
// If the rhs is known to produce a literal tuple of the same arity,
// we can directly fetch the values from the tuple.
// For example: `val (x, y) = if ... then (1, "a") else (2, "b")` becomes
// `val $1$ = if ...; val x = $1$._1; val y = $1$._2`.
// 2. Partial optimization:
// If the rhs can be typed as a tuple and matched with correct arity, we can
// return the tuple itself in the case if there are no more than one variable
// in the pattern, or return the the value if there is only one variable.

val fullTupleOptimizable =
val isMatchingTuple: Tree => Boolean = {
case Tuple(es) => tuplePatternInfo.varNum == es.length && !hasNamedArg(es)
case _ => false
}
tuplePatternInfo.arity > 0
&& tuplePatternInfo.arity == tuplePatternInfo.varNum
&& forallResults(rhs, isMatchingTuple)

// We can only optimize `val pat = if (...) e1 else e2` if:
// - `e1` and `e2` are both tuples of arity N
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
val partialTupleOptimizable =
tuplePatternInfo.arity > 0
&& tuplePatternInfo.arity == tuplePatternInfo.varNum
// We exclude the case where there is only one variable,
// because it should be handled by `makeTuple` directly.
&& tuplePatternInfo.wildcardNum < tuplePatternInfo.arity - 1

val inAliasGenerator = original match
case _: GenAlias => true
case _ => false

val vars =
if (tupleOptimizable) // include `_`
val vars: List[VarInfo] =
if fullTupleOptimizable || partialTupleOptimizable then // include `_`
pat match
case Tuple(pats) => pats.map { case id: Ident => id -> TypeTree() }
case Tuple(pats) => pats.map { case id: Ident => (id, TypeTree()) }
else
getVariables(
tree = pat,
Expand All @@ -1517,12 +1555,27 @@ object desugar {
errorOnGivenBinding
) // no `_`

val ids = for ((named, _) <- vars) yield Ident(named.name)
val ids = for ((named, tpt) <- vars) yield Ident(named.name)

val matchExpr =
if (tupleOptimizable) rhs
if fullTupleOptimizable then rhs
else
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids).withAttachment(ForArtifact, ()))
val caseDef =
if partialTupleOptimizable then
val tmpTuple = UniqueName.fresh()
// Replace all variables with wildcards in the pattern
val pat1 = pat match
case Tuple(pats) =>
val wildcardPats = pats.map(p => Ident(nme.WILDCARD).withSpan(p.span))
Tuple(wildcardPats).withSpan(pat.span)
CaseDef(
Bind(tmpTuple, pat1),
EmptyTree,
Ident(tmpTuple).withAttachment(ForArtifact, ())
)
else CaseDef(pat, EmptyTree, makeTuple(ids).withAttachment(ForArtifact, ()))
Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)

vars match {
case Nil if !mods.is(Lazy) =>
matchExpr
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,16 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
}

/** Checks whether predicate `p` is true for all result parts of this expression,
* where we zoom into Ifs, Matches, and Blocks.
* where we zoom into Ifs, Matches, Tries, and Blocks.
*/
def forallResults(tree: Tree, p: Tree => Boolean): Boolean = tree match {
def forallResults(tree: Tree, p: Tree => Boolean): Boolean = tree match
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
case Match(_, cases) => cases.forall(c => forallResults(c.body, p))
case Try(_, cases, finalizer) =>
cases.forall(c => forallResults(c.body, p))
&& (finalizer.isEmpty || forallResults(finalizer, p))
case Block(_, expr) => forallResults(expr, p)
case _ => p(tree)
}

/** The tree stripped of the possibly nested applications (term and type).
* The original tree if it's not an application.
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,8 @@ trait Applications extends Compatibility {
if selType <:< unapplyArgType then
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
fullyDefinedType(unapplyArgType, "pattern selector", tree.srcPos)
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
if selType.isBottomType then unapplyArgType
else selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
else
if !ctx.mode.is(Mode.InTypeTest) then
checkMatchable(selType, tree.srcPos, pattern = true)
Expand All @@ -1708,7 +1709,7 @@ trait Applications extends Compatibility {
val unapplyPatterns = UnapplyArgs(unapplyApp.tpe, unapplyFn, unadaptedArgs, tree.srcPos)
.typedPatterns(qual, this)
val result = assignType(cpy.UnApply(tree)(newUnapplyFn, unapplyImplicits(dummyArg, unapplyApp), unapplyPatterns), ownType)
if (ownType.stripped eq selType.stripped) || ownType.isError then result
if (ownType.stripped eq selType.stripped) || selType.isBottomType || ownType.isError then result
else tryWithTypeTest(Typed(result, TypeTree(ownType)), selType)
case tp =>
val unapplyErr = if (tp.isError) unapplyFn else notAnExtractor(unapplyFn)
Expand Down
19 changes: 16 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2774,6 +2774,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if !isFullyDefined(pt, ForceDegree.all) then
return errorTree(tree, em"expected type of $tree is not fully defined")
val body1 = typed(tree.body, pt)

// If the body is a named tuple pattern, we need to use pt for symbol type,
// because the desugared body is a regular tuple unapply.
def isNamedTuplePattern =
ctx.mode.is(Mode.Pattern)
&& pt.dealias.isNamedTupleType
&& tree.body.match
case untpd.Tuple((_: NamedArg) :: _) => true
case _ => false

body1 match {
case UnApply(fn, Nil, arg :: Nil)
if fn.symbol.exists && (fn.symbol.owner.derivesFrom(defn.TypeTestClass) || fn.symbol.owner == defn.ClassTagClass) && !body1.tpe.isError =>
Expand All @@ -2799,10 +2809,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
body1.isInstanceOf[RefTree] && !isWildcardArg(body1)
|| body1.isInstanceOf[Literal]
val symTp =
if isStableIdentifierOrLiteral || pt.dealias.isNamedTupleType then pt
// need to combine tuple element types with expected named type
if isStableIdentifierOrLiteral || isNamedTuplePattern then pt
else if isWildcardStarArg(body1)
|| pt == defn.ImplicitScrutineeTypeRef
|| pt.isBottomType
|| body1.tpe <:< pt // There is some strange interaction with gadt matching.
// and implicit scopes.
// run/t2755.scala fails to compile if this subtype test is omitted
Expand Down Expand Up @@ -3542,7 +3552,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedTuple(tree: untpd.Tuple, pt: Type)(using Context): Tree =
val tree1 = desugar.tuple(tree, pt).withAttachmentsFrom(tree)
checkDeprecatedAssignmentSyntax(tree)
if tree1 ne tree then typed(tree1, pt)
if tree1 ne tree then
val t = typed(tree1, pt)
// println(i"typedTuple: ${t} , ${t.tpe}")
t
else
val arity = tree.trees.length
val pts = pt.stripNamedTuple.tupleElementTypes match
Expand Down
12 changes: 6 additions & 6 deletions tests/neg/i7294.check
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
-- [E007] Type Mismatch Error: tests/neg/i7294.scala:7:15 --------------------------------------------------------------
-- [E007] Type Mismatch Error: tests/neg/i7294.scala:7:18 --------------------------------------------------------------
7 | case x: T => x.g(10) // error
| ^
| Found: (x : Nothing)
| Required: ?{ g: ? }
| Note that implicit conversions were not tried because the result of an implicit conversion
| must be more specific than ?{ g: [applied to (10) returning T] }
| ^^^^^^^
| Found: Any
| Required: T
|
| where: T is a type in given instance f with bounds <: foo.Foo
|
| longer explanation available when compiling with `-explain`
43 changes: 43 additions & 0 deletions tests/pos/simple-tuple-extract.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

class Test:
def f1: (Int, String, AnyRef) = (1, "2", "3")
def f2: (x: Int, y: String) = (0, "y")

def test1 =
val (a, b, c) = f1
// Desugared to:
// val $2$: (Int, String, AnyRef) =
// this.f1:(Int, String, AnyRef) @unchecked match
// {
// case $1$ @ Tuple3.unapply[Int, String, Object](_, _, _) =>
// $1$:(Int, String, AnyRef)
// }
// val a: Int = $2$._1
// val b: String = $2$._2
// val c: AnyRef = $2$._3
a + b.length() + c.toString.length()

// This pattern will not be optimized:
// val (a1, b1, c1: String) = f1

def test2 =
val (_, b, c) = f1
b.length() + c.toString.length()

val (a2, _, c2) = f1
a2 + c2.toString.length()

val (a3, _, _) = f1
a3 + 1

def test3 =
val (_, b, _) = f1
b.length() + 1

def test4 =
val (x, y) = f2
x + y.length()

def test5 =
val (_, b) = f2
b.length() + 1
Loading