Skip to content

Commit 11c65aa

Browse files
committed
Fix pattern matching for get matches
1 parent 9076944 commit 11c65aa

File tree

6 files changed

+146
-88
lines changed

6 files changed

+146
-88
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ object PatternMatcher {
379379
assert(isGetMatch(unappType))
380380
val argsPlan = {
381381
val get = ref(unappResult).select(nme.get, _.info.isParameterless)
382-
val arity = productArity(get.tpe, unapp.srcPos)
382+
val arity = productArity(get.tpe.stripNamedTuple, unapp.srcPos)
383383
if (isUnapplySeq)
384384
letAbstract(get) { getResult =>
385385
if (arity > 0) unapplyProductSeqPlan(getResult, args, arity)
@@ -389,7 +389,7 @@ object PatternMatcher {
389389
letAbstract(get) { getResult =>
390390
val selectors =
391391
if (args.tail.isEmpty) ref(getResult) :: Nil
392-
else productSelectors(get.tpe).map(ref(getResult).select(_))
392+
else productSelectors(getResult.info).map(ref(getResult).select(_))
393393
matchArgsPlan(selectors, args, onSuccess)
394394
}
395395
}

compiler/src/dotty/tools/dotc/typer/Applications.scala

+96-55
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import Names.*
1818
import StdNames.*
1919
import ContextOps.*
2020
import NameKinds.DefaultGetterName
21+
import Typer.tryEither
2122
import ProtoTypes.*
2223
import Inferencing.*
2324
import reporting.*
@@ -134,37 +135,37 @@ object Applications {
134135
sels.takeWhile(_.exists).toList
135136
}
136137

137-
def getUnapplySelectors(tp: Type, args: List[untpd.Tree], pos: SrcPos)(using Context): List[Type] =
138-
if (args.length > 1 && !(tp.derivesFrom(defn.SeqClass))) {
139-
val sels = productSelectorTypes(tp, pos)
140-
if (sels.length == args.length) sels
141-
else tp :: Nil
142-
}
143-
else tp :: Nil
144-
145138
def productSeqSelectors(tp: Type, argsNum: Int, pos: SrcPos)(using Context): List[Type] = {
146139
val selTps = productSelectorTypes(tp, pos)
147140
val arity = selTps.length
148141
val elemTp = unapplySeqTypeElemTp(selTps.last)
149142
(0 until argsNum).map(i => if (i < arity - 1) selTps(i) else elemTp).toList
150143
}
151144

152-
def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: SrcPos)(using Context): List[Type] =
153-
def getName(fn: Tree): Name =
145+
/** A utility class that matches results of unapplys with patterns. Two queriable members:
146+
* val argTypes: List[Type]
147+
* def typedPatterns(qual: untpd.Tree, typer: Typer): List[Tree]
148+
* TODO: Move into Applications trait. No need to keep it outside. But it's a large
149+
* refactor, so do this when the rest is merged.
150+
*/
151+
class UnapplyArgs(unapplyResult: Type, unapplyFn: Tree, unadaptedArgs: List[untpd.Tree], pos: SrcPos)(using Context):
152+
private var args = unadaptedArgs
153+
154+
private def getName(fn: Tree): Name =
154155
fn match
155156
case TypeApply(fn, _) => getName(fn)
156157
case Apply(fn, _) => getName(fn)
157158
case fn: RefTree => fn.name
158-
val unapplyName = getName(unapplyFn) // tolerate structural `unapply`, which does not have a symbol
159+
private val unapplyName = getName(unapplyFn) // tolerate structural `unapply`, which does not have a symbol
159160

160-
def getTp = extractorMemberType(unapplyResult, nme.get, pos)
161+
private def getTp = extractorMemberType(unapplyResult, nme.get, pos)
161162

162-
def fail = {
163+
private def fail = {
163164
report.error(UnapplyInvalidReturnType(unapplyResult, unapplyName), pos)
164165
Nil
165166
}
166167

167-
def unapplySeq(tp: Type)(fallback: => List[Type]): List[Type] =
168+
private def unapplySeq(tp: Type)(fallback: => List[Type]): List[Type] =
168169
val elemTp = unapplySeqTypeElemTp(tp)
169170
if elemTp.exists then
170171
args.map(Function.const(elemTp))
@@ -174,26 +175,84 @@ object Applications {
174175
tp.tupleElementTypes.getOrElse(Nil)
175176
else fallback
176177

177-
if unapplyName == nme.unapplySeq then
178-
unapplySeq(unapplyResult):
179-
if (isGetMatch(unapplyResult, pos)) unapplySeq(getTp)(fail)
180-
else fail
181-
else
182-
assert(unapplyName == nme.unapply)
183-
if isProductMatch(unapplyResult, args.length, pos) then
184-
productSelectorTypes(unapplyResult, pos)
185-
else if isGetMatch(unapplyResult, pos) then
186-
getUnapplySelectors(getTp, args, pos)
187-
else if unapplyResult.derivesFrom(defn.BooleanClass) then
188-
Nil
189-
else if defn.isProductSubType(unapplyResult) && productArity(unapplyResult, pos) != 0 then
190-
productSelectorTypes(unapplyResult, pos)
191-
// this will cause a "wrong number of arguments in pattern" error later on,
192-
// which is better than the message in `fail`.
193-
else if unapplyResult.derivesFrom(defn.NonEmptyTupleClass) then
194-
unapplyResult.tupleElementTypes.getOrElse(Nil)
195-
else fail
196-
end unapplyArgs
178+
private def tryAdaptPatternArgs(elems: List[untpd.Tree], pt: Type)(using Context): Option[List[untpd.Tree]] =
179+
tryEither[Option[List[untpd.Tree]]]
180+
(Some(desugar.adaptPatternArgs(elems, pt)))
181+
((_, _) => None)
182+
183+
private def getUnapplySelectors(tp: Type)(using Context): List[Type] =
184+
if args.length > 1 && !(tp.derivesFrom(defn.SeqClass)) then
185+
productUnapplySelectors(tp).getOrElse:
186+
// There are unapplys with return types which have `get` and `_1, ..., _n`
187+
// as members, but which are not subtypes of Product. So `productUnapplySelectors`
188+
// would return None for these, but they are still valid types
189+
// for a get match. A test case is pos/extractors.scala.
190+
val sels = productSelectorTypes(tp, pos)
191+
if (sels.length == args.length) sels
192+
else tp :: Nil
193+
else tp :: Nil
194+
195+
private def productUnapplySelectors(tp: Type)(using Context): Option[List[Type]] =
196+
if defn.isProductSubType(tp) then
197+
tryAdaptPatternArgs(args, tp) match
198+
case Some(args1) if isProductMatch(tp, args1.length, pos) =>
199+
args = args1
200+
Some(productSelectorTypes(tp, pos))
201+
case _ => None
202+
else tp.widen.normalized.dealias match
203+
case tp @ defn.NamedTuple(_, tt) =>
204+
tryAdaptPatternArgs(args, tp) match
205+
case Some(args1) =>
206+
args = args1
207+
tt.tupleElementTypes
208+
case _ => None
209+
case _ => None
210+
211+
/** The computed argument types which will be the scutinees of the sub-patterns. */
212+
val argTypes: List[Type] =
213+
if unapplyName == nme.unapplySeq then
214+
unapplySeq(unapplyResult):
215+
if (isGetMatch(unapplyResult, pos)) unapplySeq(getTp)(fail)
216+
else fail
217+
else
218+
assert(unapplyName == nme.unapply)
219+
productUnapplySelectors(unapplyResult).getOrElse:
220+
if isGetMatch(unapplyResult, pos) then
221+
getUnapplySelectors(getTp)
222+
else if unapplyResult.derivesFrom(defn.BooleanClass) then
223+
Nil
224+
else if unapplyResult.derivesFrom(defn.NonEmptyTupleClass) then
225+
unapplyResult.tupleElementTypes.getOrElse(Nil)
226+
else if defn.isProductSubType(unapplyResult) && productArity(unapplyResult, pos) != 0 then
227+
productSelectorTypes(unapplyResult, pos)
228+
// this will cause a "wrong number of arguments in pattern" error later on,
229+
// which is better than the message in `fail`.
230+
else fail
231+
232+
/** The typed pattens of this unapply */
233+
def typedPatterns(qual: untpd.Tree, typer: Typer): List[Tree] =
234+
unapp.println(i"unapplyQual = $qual, unapplyArgs = ${unapplyResult} with $argTypes / $args")
235+
for argType <- argTypes do
236+
assert(!isBounds(argType), unapplyResult.show)
237+
val alignedArgs = argTypes match
238+
case argType :: Nil
239+
if args.lengthCompare(1) > 0
240+
&& Feature.autoTuplingEnabled
241+
&& defn.isTupleNType(argType) =>
242+
untpd.Tuple(args) :: Nil
243+
case _ =>
244+
args
245+
val alignedArgTypes =
246+
if argTypes.length == alignedArgs.length then
247+
argTypes
248+
else
249+
report.error(UnapplyInvalidNumberOfArguments(qual, argTypes), pos)
250+
argTypes.take(args.length) ++
251+
List.fill(argTypes.length - args.length)(WildcardType)
252+
alignedArgs.lazyZip(alignedArgTypes).map(typer.typed(_, _))
253+
.showing(i"unapply patterns = $result", unapp)
254+
255+
end UnapplyArgs
197256

198257
def wrapDefs(defs: mutable.ListBuffer[Tree] | Null, tree: Tree)(using Context): Tree =
199258
if (defs != null && defs.nonEmpty) tpd.Block(defs.toList, tree) else tree
@@ -1452,28 +1511,10 @@ trait Applications extends Compatibility {
14521511
loop(unapp)
14531512
res.result()
14541513
}
1455-
val args = desugar.adaptPatternArgs(unadaptedArgs, unapplyApp.tpe)
1456-
1457-
var argTypes = unapplyArgs(unapplyApp.tpe.stripNamedTuple, unapplyFn, args, tree.srcPos)
1458-
unapp.println(i"unapplyArgs = ${unapplyApp.tpe} with $argTypes / $args")
1459-
for (argType <- argTypes) assert(!isBounds(argType), unapplyApp.tpe.show)
1460-
val bunchedArgs = argTypes match {
1461-
case argType :: Nil =>
1462-
if args.lengthCompare(1) > 0
1463-
&& Feature.autoTuplingEnabled
1464-
&& defn.isTupleNType(argType)
1465-
then untpd.Tuple(args) :: Nil
1466-
else args
1467-
case _ => args
1468-
}
1469-
if (argTypes.length != bunchedArgs.length) {
1470-
report.error(UnapplyInvalidNumberOfArguments(qual, argTypes), tree.srcPos)
1471-
argTypes = argTypes.take(args.length) ++
1472-
List.fill(argTypes.length - args.length)(WildcardType)
1473-
}
1474-
val unapplyPatterns = bunchedArgs.lazyZip(argTypes) map (typed(_, _))
1514+
1515+
val unapplyPatterns = UnapplyArgs(unapplyApp.tpe, unapplyFn, unadaptedArgs, tree.srcPos)
1516+
.typedPatterns(qual, this)
14751517
val result = assignType(cpy.UnApply(tree)(unapplyFn, unapplyImplicits(unapplyApp), unapplyPatterns), ownType)
1476-
unapp.println(s"unapply patterns = $unapplyPatterns")
14771518
if (ownType.stripped eq selType.stripped) || ownType.isError then result
14781519
else tryWithTypeTest(Typed(result, TypeTree(ownType)), selType)
14791520
case tp =>

compiler/src/dotty/tools/dotc/typer/Checking.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import config.Printers.{typr, patmatch}
2929
import NameKinds.DefaultGetterName
3030
import NameOps.*
3131
import SymDenotations.{NoCompleter, NoDenotation}
32-
import Applications.unapplyArgs
32+
import Applications.UnapplyArgs
3333
import Inferencing.isFullyDefined
3434
import transform.patmat.SpaceEngine.{isIrrefutable, isIrrefutableQuotePattern}
3535
import transform.ValueClasses.underlyingOfValueClass
@@ -952,7 +952,7 @@ trait Checking {
952952
case UnApply(fn, implicits, pats) =>
953953
check(pat, pt) &&
954954
(isIrrefutable(fn, pats.length) || fail(pat, pt, Reason.RefutableExtractor)) && {
955-
val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.srcPos)
955+
val argPts = UnapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.srcPos).argTypes
956956
pats.corresponds(argPts)(recur)
957957
}
958958
case Alternative(pats) =>

compiler/src/dotty/tools/dotc/typer/Typer.scala

+25-25
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,31 @@ object Typer {
113113
def rememberSearchFailure(tree: tpd.Tree, fail: SearchFailure) =
114114
tree.putAttachment(HiddenSearchFailure,
115115
fail :: tree.attachmentOrElse(HiddenSearchFailure, Nil))
116+
117+
def tryEither[T](op: Context ?=> T)(fallBack: (T, TyperState) => T)(using Context): T = {
118+
val nestedCtx = ctx.fresh.setNewTyperState()
119+
val result = op(using nestedCtx)
120+
if (nestedCtx.reporter.hasErrors && !nestedCtx.reporter.hasStickyErrors) {
121+
record("tryEither.fallBack")
122+
fallBack(result, nestedCtx.typerState)
123+
}
124+
else {
125+
record("tryEither.commit")
126+
nestedCtx.typerState.commit()
127+
result
128+
}
129+
}
130+
131+
/** Try `op1`, if there are errors, try `op2`, if `op2` also causes errors, fall back
132+
* to errors and result of `op1`.
133+
*/
134+
def tryAlternatively[T](op1: Context ?=> T)(op2: Context ?=> T)(using Context): T =
135+
tryEither(op1) { (failedVal, failedState) =>
136+
tryEither(op2) { (_, _) =>
137+
failedState.commit()
138+
failedVal
139+
}
140+
}
116141
}
117142
/** Typecheck trees, the main entry point is `typed`.
118143
*
@@ -3461,31 +3486,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
34613486
def typedPattern(tree: untpd.Tree, selType: Type = WildcardType)(using Context): Tree =
34623487
withMode(Mode.Pattern)(typed(tree, selType))
34633488

3464-
def tryEither[T](op: Context ?=> T)(fallBack: (T, TyperState) => T)(using Context): T = {
3465-
val nestedCtx = ctx.fresh.setNewTyperState()
3466-
val result = op(using nestedCtx)
3467-
if (nestedCtx.reporter.hasErrors && !nestedCtx.reporter.hasStickyErrors) {
3468-
record("tryEither.fallBack")
3469-
fallBack(result, nestedCtx.typerState)
3470-
}
3471-
else {
3472-
record("tryEither.commit")
3473-
nestedCtx.typerState.commit()
3474-
result
3475-
}
3476-
}
3477-
3478-
/** Try `op1`, if there are errors, try `op2`, if `op2` also causes errors, fall back
3479-
* to errors and result of `op1`.
3480-
*/
3481-
def tryAlternatively[T](op1: Context ?=> T)(op2: Context ?=> T)(using Context): T =
3482-
tryEither(op1) { (failedVal, failedState) =>
3483-
tryEither(op2) { (_, _) =>
3484-
failedState.commit()
3485-
failedVal
3486-
}
3487-
}
3488-
34893489
/** Is `pt` a prototype of an `apply` selection, or a parameterless function yielding one? */
34903490
def isApplyProto(pt: Type)(using Context): Boolean = pt.revealIgnored match {
34913491
case pt: SelectionProto => pt.name == nme.apply

tests/run/named-patterns.check

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ name Bob
33
age 22
44
age 22, name Bob
55
Bob, 22
6+
name Bob, age 22
7+
name (Bob,22)
8+
age (Bob,22)
9+
age 22, name Bob
10+
Bob, 22
611
1003 Lausanne, Rue de la Gare 44
712
1003 Lausanne
813
Rue de la Gare in Lausanne

tests/run/named-patterns.scala

+16-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ object Test1:
66
object Person:
77
def unapply(p: Person): (name: String, age: Int) = (p.name, p.age)
88

9+
class Person2(val name: String, val age: Int)
10+
object Person2:
11+
def unapply(p: Person2): Option[(name: String, age: Int)] = Some((p.name, p.age))
12+
913
case class Address(city: String, zip: Int, street: String, number: Int)
1014

1115
@main def Test =
@@ -21,6 +25,18 @@ object Test1:
2125
bob match
2226
case Person(age, name) => println(s"$age, $name")
2327

28+
val bob2 = Person2("Bob", 22)
29+
bob2 match
30+
case Person2(name = n, age = a) => println(s"name $n, age $a")
31+
bob2 match
32+
case Person2(name = n) => println(s"name $n")
33+
bob2 match
34+
case Person2(age = a) => println(s"age $a")
35+
bob2 match
36+
case Person2(age = a, name = n) => println(s"age $a, name $n")
37+
bob2 match
38+
case Person2(age, name) => println(s"$age, $name")
39+
2440
val addr = Address("Lausanne", 1003, "Rue de la Gare", 44)
2541
addr match
2642
case Address(city = c, zip = z, street = s, number = n) =>
@@ -37,7 +53,3 @@ object Test1:
3753
addr match
3854
case Address(c, z, s, number) =>
3955
println(s"$z $c, $s $number")
40-
41-
42-
43-

0 commit comments

Comments
 (0)