Skip to content

Commit 0b90f35

Browse files
committed
cleaner way to extract class or tuple proxy
1 parent 0e09603 commit 0b90f35

File tree

7 files changed

+138
-53
lines changed

7 files changed

+138
-53
lines changed

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
689689
}
690690

691691
private def erasePair(tp: Type)(using Context): Type = {
692-
val arity = tp.tupleArity
692+
// NOTE: `tupleArity` does not consider TypeRef(EmptyTuple$) equivalent to EmptyTuple.type,
693+
// we fix this for printers, but type erasure should be preserved.
694+
val arity = tp.tupleArity()
693695
if (arity < 0) defn.ProductClass.typeRef
694696
else if (arity <= Definitions.MaxTupleArity) defn.TupleType(arity).nn
695697
else defn.TupleXXLClass.typeRef

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,71 @@ import scala.annotation.internal.sharable
4040
import scala.annotation.threadUnsafe
4141

4242
import dotty.tools.dotc.transform.SymUtils._
43+
import dotty.tools.dotc.transform.TypeUtils.*
4344

4445
object Types {
4546

4647
@sharable private var nextId = 0
4748

4849
implicit def eqType: CanEqual[Type, Type] = CanEqual.derived
4950

51+
object GenericTupleType:
52+
def unapply(tp: Type)(using Context): Option[List[Type]] = tp match
53+
case tp @ AppliedType(ref: TypeRef, _)
54+
if ref.isRef(defn.PairClass) && tp.tupleArity(relaxEmptyTuple = true) > 0 =>
55+
Some(tp.tupleElementTypes)
56+
case _ => None
57+
58+
enum ClassOrTuple:
59+
case ClassSymbol(cls: Symbol)
60+
case GenericTuple(tuplArity: Int, tpArgs: List[Type])
61+
case NoClass
62+
63+
def isGenericTuple: Boolean = this.isInstanceOf[GenericTuple]
64+
65+
private var _asClass: Symbol | Null = null
66+
67+
/** tuple arity, works for TupleN classes and generic tuples */
68+
final def arity(using Context): Int = this match
69+
case GenericTuple(arity, _) => arity
70+
case ClassSymbol(cls) =>
71+
if defn.isTupleClass(cls) then
72+
cls.typeParams.length
73+
else
74+
-1
75+
case NoClass => -1
76+
77+
def equiv(that: ClassOrTuple)(using Context): Boolean = (this.arity, that.arity) match
78+
case (n, m) if n > 0 || m > 0 =>
79+
// we shortcut when at least one was a tuple.
80+
// This protects us from comparing classes for two TupleXXL with different arities.
81+
n == m
82+
case _ => this.asClass eq that.asClass // class equality otherwise
83+
84+
def isSub(that: ClassOrTuple)(using Context): Boolean = (this.arity, that.arity) match
85+
case (n, m) if n > 0 || m > 0 =>
86+
// we shortcut when at least one was a tuple.
87+
// This protects us from comparing classes for two TupleXXL with different arities.
88+
n == m
89+
case _ => this.asClass isSubClass that.asClass
90+
91+
def asClass(using Context): Symbol =
92+
val local = _asClass
93+
if local == null then
94+
val res = this match
95+
case ClassSymbol(cls) => cls
96+
case GenericTuple(arity, _) =>
97+
if arity <= Definitions.MaxTupleArity then defn.TupleType(arity).nn.classSymbol
98+
else defn.TupleXXLClass
99+
case NoClass => NoSymbol
100+
_asClass = res
101+
res
102+
else local
103+
object ClassOrTuple:
104+
def tuple(tps: List[Type]): ClassOrTuple = ClassOrTuple.GenericTuple(tps.size, tps)
105+
106+
end ClassOrTuple
107+
50108
/** Main class representing types.
51109
*
52110
* The principal subclasses and sub-objects are as follows:
@@ -491,6 +549,8 @@ object Types {
491549
/** The least class or trait of which this type is a subtype or parameterized
492550
* instance, or NoSymbol if none exists (either because this type is not a
493551
* value type, or because superclasses are ambiguous).
552+
*
553+
* If modified, update [[underlyingClassOrTuple]] as well.
494554
*/
495555
final def classSymbol(using Context): Symbol = this match
496556
case tp: TypeRef =>
@@ -524,6 +584,41 @@ object Types {
524584
case _ =>
525585
NoSymbol
526586

587+
/** Follows [[classSymbol]], but also escapes generic tuples to a proxy representing their class and type arguments.
588+
*/
589+
final def underlyingClassOrTuple(using Context): ClassOrTuple = this match
590+
case tp: TypeRef =>
591+
val sym = tp.symbol
592+
if (sym.isClass) ClassOrTuple.ClassSymbol(sym) else tp.superType.underlyingClassOrTuple
593+
case GenericTupleType(args) => ClassOrTuple.tuple(args)
594+
case tp: TypeProxy =>
595+
tp.underlying.underlyingClassOrTuple
596+
case tp: ClassInfo =>
597+
ClassOrTuple.ClassSymbol(tp.cls)
598+
case AndType(l, r) =>
599+
val lsym = l.underlyingClassOrTuple
600+
val rsym = r.underlyingClassOrTuple
601+
if (lsym isSub rsym) lsym
602+
else if (rsym isSub lsym) rsym
603+
else ClassOrTuple.NoClass
604+
case tp: OrType =>
605+
if tp.tp1.hasClassSymbol(defn.NothingClass) then
606+
tp.tp2.underlyingClassOrTuple
607+
else if tp.tp2.hasClassSymbol(defn.NothingClass) then
608+
tp.tp1.underlyingClassOrTuple
609+
else
610+
def tp1Null = tp.tp1.hasClassSymbol(defn.NullClass)
611+
def tp2Null = tp.tp2.hasClassSymbol(defn.NullClass)
612+
if ctx.erasedTypes && (tp1Null || tp2Null) then
613+
val otherSide = if tp1Null then tp.tp2.underlyingClassOrTuple else tp.tp1.underlyingClassOrTuple
614+
if otherSide.asClass.isValueClass then ClassOrTuple.ClassSymbol(defn.AnyClass) else otherSide
615+
else
616+
tp.join.underlyingClassOrTuple
617+
case _: JavaArrayType =>
618+
ClassOrTuple.ClassSymbol(defn.ArrayClass)
619+
case _ =>
620+
ClassOrTuple.NoClass
621+
527622
/** The least (wrt <:<) set of symbols satisfying the `include` predicate of which this type is a subtype
528623
*/
529624
final def parentSymbols(include: Symbol => Boolean)(using Context): List[Symbol] = this match {

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
218218
val cls = tycon.typeSymbol
219219
if tycon.isRepeatedParam then toTextLocal(args.head) ~ "*"
220220
else if defn.isFunctionClass(cls) then toTextFunction(args, cls.name.isContextFunction, cls.name.isErasedFunction)
221-
else if tp.tupleArity >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes)
221+
else if tp.tupleArity(relaxEmptyTuple = true) >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes)
222222
else if isInfixType(tp) then
223223
val l :: r :: Nil = args: @unchecked
224224
val opName = tyconName(tycon)

compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ object GenericSignatures {
248248
case _ => jsig(elemtp)
249249

250250
case RefOrAppliedType(sym, pre, args) =>
251-
if (sym == defn.PairClass && tp.tupleArity > Definitions.MaxTupleArity)
251+
if (sym == defn.PairClass && tp.tupleArity() > Definitions.MaxTupleArity)
252252
jsig(defn.TupleXXLClass.typeRef)
253253
else if (isTypeParameterInSig(sym, sym0)) {
254254
assert(!sym.isAliasType, "Unexpected alias type: " + sym)

compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,16 @@ object TypeUtils {
5151

5252
/** The arity of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs,
5353
* or -1 if this is not a tuple type.
54+
*
55+
* @param relaxEmptyTuple if true then TypeRef(EmptyTuple$) =:= EmptyTuple.type
5456
*/
55-
def tupleArity(using Context): Int = self match {
57+
def tupleArity(relaxEmptyTuple: Boolean = false)(using Context): Int = self match {
5658
case AppliedType(tycon, _ :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
57-
val arity = tl.tupleArity
59+
val arity = tl.tupleArity()
5860
if (arity < 0) arity else arity + 1
5961
case self: SingletonType =>
6062
if self.termSymbol == defn.EmptyTupleModule then 0 else -1
61-
case self: TypeRef if self.classSymbol == defn.EmptyTupleModule.moduleClass =>
63+
case self: TypeRef if relaxEmptyTuple && self.classSymbol == defn.EmptyTupleModule.moduleClass =>
6264
0
6365
case self if defn.isTupleClass(self.classSymbol) =>
6466
self.dealias.argInfos.length

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -281,34 +281,17 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
281281

282282
private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =
283283

284-
var isSafeGenericTuple = Option.empty[(Symbol, List[Type])]
285-
286-
/** do all parts match the class symbol? Or can we extract a generic tuple type out? */
287-
def acceptable(tp: Type, cls: Symbol): Boolean =
288-
var genericTupleParts = List.empty[(Symbol, List[Type])]
289-
290-
def acceptableGenericTuple(tp: AppliedType): Boolean =
291-
val tupleArgs = tp.tupleElementTypes
292-
val arity = tupleArgs.size
293-
val isOk = arity <= Definitions.MaxTupleArity
294-
if isOk then
295-
genericTupleParts ::= {
296-
val cls = defn.TupleType(arity).nn.classSymbol
297-
(cls, tupleArgs)
298-
}
299-
isOk
300-
301-
def inner(tp: Type, cls: Symbol): Boolean = tp match
302-
case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] => false
303-
case tp @ AppliedType(cons: TypeRef, _) if cons.isRef(defn.PairClass) => acceptableGenericTuple(tp)
304-
case tp: TypeProxy => inner(tp.underlying, cls)
305-
case OrType(tp1, tp2) => inner(tp1, cls) && inner(tp2, cls)
306-
case _ => tp.classSymbol eq cls
284+
extension (clsOrTuple: ClassOrTuple) def isGenericProd(using Context) =
285+
clsOrTuple.isGenericTuple || clsOrTuple.asClass.isGenericProduct && canAccessCtor(clsOrTuple.asClass)
307286

308-
val classPartsMatch = inner(tp, cls)
309-
classPartsMatch && genericTupleParts.map((cls, _) => cls).distinct.sizeIs <= 1 &&
310-
{ isSafeGenericTuple = genericTupleParts.headOption ; true }
311-
end acceptable
287+
/** do all parts match the class symbol? */
288+
def acceptable(tp: Type, clsOrTuple: ClassOrTuple): Boolean = tp match
289+
case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] => false
290+
case OrType(tp1, tp2) => acceptable(tp1, clsOrTuple) && acceptable(tp2, clsOrTuple)
291+
case GenericTupleType(args) if args.size <= Definitions.MaxTupleArity =>
292+
ClassOrTuple.tuple(args).equiv(clsOrTuple)
293+
case tp: TypeProxy => acceptable(tp.underlying, clsOrTuple)
294+
case _ => tp.underlyingClassOrTuple.equiv(clsOrTuple)
312295

313296
/** for a case class, if it will have an anonymous mirror,
314297
* check that its constructor can be accessed
@@ -326,13 +309,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
326309
def genAnonyousMirror(cls: Symbol): Boolean =
327310
cls.is(Scala2x) || cls.linkedClass.is(Case)
328311

329-
def makeProductMirror(cls: Symbol): TreeWithErrors =
330-
val mirroredClass = isSafeGenericTuple.fold(cls)((cls, _) => cls)
312+
def makeProductMirror(clsOrTuple: ClassOrTuple): TreeWithErrors =
313+
val mirroredClass = clsOrTuple.asClass
331314
val accessors = mirroredClass.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
332315
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
333-
val nestedPairs = isSafeGenericTuple.map((_, tps) => TypeOps.nestedPairs(tps)).getOrElse {
334-
TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
335-
}
316+
val nestedPairs = clsOrTuple match
317+
case ClassOrTuple.GenericTuple(_, args) => TypeOps.nestedPairs(args)
318+
case _ => TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
336319
val (monoType, elemsType) = mirroredType match
337320
case mirroredType: HKTypeLambda =>
338321
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
@@ -342,25 +325,30 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
342325
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
343326
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
344327
val mirrorType =
345-
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
328+
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, mirroredClass.name, formal)
346329
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
347330
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
348331
val mirrorRef =
349332
if genAnonyousMirror(mirroredClass) then
350-
anonymousMirror(monoType, ExtendsProductMirror, isSafeGenericTuple.map(_(1).size), span)
333+
val arity = clsOrTuple match
334+
case ClassOrTuple.GenericTuple(arity, _) => Some(arity)
335+
case _ => None
336+
anonymousMirror(monoType, ExtendsProductMirror, arity, span)
351337
else companionPath(mirroredType, span)
352338
withNoErrors(mirrorRef.cast(mirrorType))
353339
end makeProductMirror
354340

355-
def getError(cls: Symbol): String =
341+
def getError(clsOrTuple: ClassOrTuple): String =
356342
val reason =
357-
if !cls.isGenericProduct then
358-
i"because ${cls.whyNotGenericProduct}"
359-
else if !canAccessCtor(cls) then
360-
i"because the constructor of $cls is innaccessible from the calling scope."
343+
if !clsOrTuple.isGenericTuple then
344+
if !clsOrTuple.asClass.isGenericProduct then
345+
i"because ${clsOrTuple.asClass.whyNotGenericProduct}"
346+
else if !canAccessCtor(clsOrTuple.asClass) then
347+
i"because the constructor of ${clsOrTuple.asClass} is innaccessible from the calling scope."
348+
else ""
361349
else
362350
""
363-
i"$cls is not a generic product $reason"
351+
i"${clsOrTuple.asClass} is not a generic product $reason"
364352
end getError
365353

366354
mirroredType match
@@ -378,13 +366,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
378366
val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, module.name, formal)
379367
withNoErrors(modulePath.cast(mirrorType))
380368
else
381-
val cls = mirroredType.classSymbol
382-
if acceptable(mirroredType, cls)
383-
&& isSafeGenericTuple.isDefined || (cls.isGenericProduct && canAccessCtor(cls))
384-
then
385-
makeProductMirror(cls)
369+
val clsOrTuple = mirroredType.underlyingClassOrTuple
370+
if acceptable(mirroredType, clsOrTuple) && clsOrTuple.isGenericProd then
371+
makeProductMirror(clsOrTuple)
386372
else
387-
(EmptyTree, List(getError(cls)))
373+
(EmptyTree, List(getError(clsOrTuple)))
388374
end productMirror
389375

390376
private def sumMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2766,7 +2766,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
27662766
typed(desugar.smallTuple(tree).withSpan(tree.span), pt)
27672767
else {
27682768
val pts =
2769-
if (arity == pt.tupleArity) pt.tupleElementTypes
2769+
if (arity == pt.tupleArity()) pt.tupleElementTypes
27702770
else List.fill(arity)(defn.AnyType)
27712771
val elems = tree.trees.lazyZip(pts).map(
27722772
if ctx.mode.is(Mode.Type) then typedType(_, _, mapPatternBounds = true)

0 commit comments

Comments
 (0)