Skip to content

Commit ac538f2

Browse files
committed
synthesize mirrors for small generic tuples
- handles generic tuples of different arity
1 parent 9d2d194 commit ac538f2

File tree

12 files changed

+145
-23
lines changed

12 files changed

+145
-23
lines changed

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
689689
}
690690

691691
private def erasePair(tp: Type)(using Context): Type = {
692-
val arity = tp.tupleArity
692+
// NOTE: `tupleArity` does not consider TypeRef(EmptyTuple$) equivalent to EmptyTuple.type,
693+
// we fix this for printers, but type erasure should be preserved.
694+
val arity = tp.tupleArity()
693695
if (arity < 0) defn.ProductClass.typeRef
694696
else if (arity <= Definitions.MaxTupleArity) defn.TupleType(arity).nn
695697
else defn.TupleXXLClass.typeRef

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,22 @@ import scala.annotation.internal.sharable
4040
import scala.annotation.threadUnsafe
4141

4242
import dotty.tools.dotc.transform.SymUtils._
43+
import dotty.tools.dotc.transform.TypeUtils.*
4344

4445
object Types {
4546

4647
@sharable private var nextId = 0
4748

4849
implicit def eqType: CanEqual[Type, Type] = CanEqual.derived
4950

51+
object GenericTupleType:
52+
def unapply(tp: AppliedType)(using Context): Option[List[Type]] =
53+
tp.tycon match
54+
case TypeRef(_, cls: ClassSymbol) if tp.tupleArity(relaxEmptyTuple = true) > 0 => // avoid type aliases
55+
Some(tp.tupleElementTypes)
56+
case _ =>
57+
None
58+
5059
/** Main class representing types.
5160
*
5261
* The principal subclasses and sub-objects are as follows:

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
218218
val cls = tycon.typeSymbol
219219
if tycon.isRepeatedParam then toTextLocal(args.head) ~ "*"
220220
else if defn.isFunctionClass(cls) then toTextFunction(args, cls.name.isContextFunction, cls.name.isErasedFunction)
221-
else if tp.tupleArity >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes)
221+
else if tp.tupleArity(relaxEmptyTuple = true) >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes)
222222
else if isInfixType(tp) then
223223
val l :: r :: Nil = args: @unchecked
224224
val opName = tyconName(tycon)

compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ object GenericSignatures {
248248
case _ => jsig(elemtp)
249249

250250
case RefOrAppliedType(sym, pre, args) =>
251-
if (sym == defn.PairClass && tp.tupleArity > Definitions.MaxTupleArity)
251+
if (sym == defn.PairClass && tp.tupleArity() > Definitions.MaxTupleArity)
252252
jsig(defn.TupleXXLClass.typeRef)
253253
else if (isTypeParameterInSig(sym, sym0)) {
254254
assert(!sym.isAliasType, "Unexpected alias type: " + sym)

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ object SyntheticMembers {
2626

2727
/** Attachment recording that an anonymous class should extend Mirror.Sum */
2828
val ExtendsSumMirror: Property.StickyKey[Unit] = new Property.StickyKey
29+
30+
/** Attachment recording that an anonymous class should extend Mirror.Sum */
31+
val GenericTupleArity: Property.StickyKey[Int] = new Property.StickyKey
2932
}
3033

3134
/** Synthetic method implementations for case classes, case objects,
@@ -601,7 +604,11 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
601604
else if (impl.removeAttachment(ExtendsSingletonMirror).isDefined)
602605
makeSingletonMirror()
603606
else if (impl.removeAttachment(ExtendsProductMirror).isDefined)
604-
makeProductMirror(monoType.typeRef.dealias.classSymbol)
607+
val tupleArity = impl.removeAttachment(GenericTupleArity)
608+
val cls = tupleArity match
609+
case Some(n) => defn.TupleType(n).nn.classSymbol
610+
case _ => monoType.typeRef.dealias.classSymbol
611+
makeProductMirror(cls)
605612
else if (impl.removeAttachment(ExtendsSumMirror).isDefined)
606613
makeSumMirror(monoType.typeRef.dealias.classSymbol)
607614

compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,17 @@ object TypeUtils {
5151

5252
/** The arity of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs,
5353
* or -1 if this is not a tuple type.
54+
*
55+
* @param relaxEmptyTuple if true then TypeRef(EmptyTuple$) =:= EmptyTuple.type
5456
*/
55-
def tupleArity(using Context): Int = self match {
57+
def tupleArity(relaxEmptyTuple: Boolean = false)(using Context): Int = self match {
5658
case AppliedType(tycon, _ :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
57-
val arity = tl.tupleArity
59+
val arity = tl.tupleArity(relaxEmptyTuple)
5860
if (arity < 0) arity else arity + 1
5961
case self: SingletonType =>
6062
if self.termSymbol == defn.EmptyTupleModule then 0 else -1
63+
case self: TypeRef if relaxEmptyTuple && self.classSymbol == defn.EmptyTupleModule.moduleClass =>
64+
0
6165
case self if defn.isTupleClass(self.classSymbol) =>
6266
self.dealias.argInfos.length
6367
case _ =>
@@ -69,12 +73,14 @@ object TypeUtils {
6973
case AppliedType(tycon, hd :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
7074
hd :: tl.tupleElementTypes
7175
case self: SingletonType =>
72-
assert(self.termSymbol == defn.EmptyTupleModule, "not a tuple")
76+
assert(self.termSymbol == defn.EmptyTupleModule, i"not a tuple `$self`")
77+
Nil
78+
case self: TypeRef if self.classSymbol == defn.EmptyTupleModule.moduleClass =>
7379
Nil
7480
case self if defn.isTupleClass(self.classSymbol) =>
7581
self.dealias.argInfos
76-
case _ =>
77-
throw new AssertionError("not a tuple")
82+
case tp =>
83+
throw new AssertionError(i"not a tuple `$tp`")
7884
}
7985

8086
/** The `*:` equivalent of an instance of a Tuple class */

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,19 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
223223
/** Create an anonymous class `new Object { type MirroredMonoType = ... }`
224224
* and mark it with given attachment so that it is made into a mirror at PostTyper.
225225
*/
226-
private def anonymousMirror(monoType: Type, attachment: Property.StickyKey[Unit], span: Span)(using Context) =
226+
private def anonymousMirror(monoType: Type, attachment: Property.StickyKey[Unit], tupleArity: Option[Int], span: Span)(using Context) =
227227
if ctx.isAfterTyper then ctx.compilationUnit.needsMirrorSupport = true
228228
val monoTypeDef = untpd.TypeDef(tpnme.MirroredMonoType, untpd.TypeTree(monoType))
229-
val newImpl = untpd.Template(
229+
var newImpl = untpd.Template(
230230
constr = untpd.emptyConstructor,
231231
parents = untpd.TypeTree(defn.ObjectType) :: Nil,
232232
derived = Nil,
233233
self = EmptyValDef,
234234
body = monoTypeDef :: Nil
235235
).withAttachment(attachment, ())
236+
tupleArity.foreach { n =>
237+
newImpl = newImpl.withAttachment(GenericTupleArity, n)
238+
}
236239
typer.typed(untpd.New(newImpl).withSpan(span))
237240

238241
/** The mirror type
@@ -279,6 +282,20 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
279282
private[Synthesizer] enum MirrorSource:
280283
case ClassSymbol(cls: Symbol)
281284
case Singleton(src: Symbol, tref: TermRef)
285+
case GenericTuple(tuplArity: Int, tpArgs: List[Type])
286+
287+
def whyNotGenericProd(using Context): Option[String] = this match
288+
case GenericTuple(arity, _) =>
289+
val maxArity = Definitions.MaxTupleArity
290+
if arity <= maxArity then None
291+
else Some(i"it reduces to tuple with arity $arity, expected arity <= $maxArity")
292+
case ClassSymbol(cls) => if cls.isGenericProduct then None else Some(cls.whyNotGenericProduct)
293+
case _ => None
294+
295+
/** tuple arity, works for TupleN classes and generic tuples */
296+
final def arity(using Context): Int = this match
297+
case GenericTuple(arity, _) => arity
298+
case _ => -1
282299

283300
/** A comparison that chooses the most specific MirrorSource, this is guided by what is necessary for
284301
* `Mirror.Product.fromProduct`. i.e. its result type should be compatible with the erasure of `mirroredType`.
@@ -289,12 +306,29 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
289306
case (ClassSymbol(cls1), ClassSymbol(cls2)) => cls1.isSubClass(cls2)
290307
case (Singleton(src1, _), Singleton(src2, _)) => src1 eq src2
291308
case (_: ClassSymbol, _: Singleton) => false
309+
case _ => (this.arity, that.arity) match
310+
case (n, m) if n > 0 || m > 0 =>
311+
// we shortcut when at least one was a tuple.
312+
// This protects us from comparing classes for two TupleXXL with different arities.
313+
n == m
314+
case _ => false
315+
316+
def asClass(using Context): Symbol = this match
317+
case ClassSymbol(cls) => cls
318+
case Singleton(src, _) => src.info.classSymbol
319+
case GenericTuple(arity, _) =>
320+
if arity <= Definitions.MaxTupleArity then defn.TupleType(arity).nn.classSymbol
321+
else defn.PairClass
292322

293323
def show(using Context): String = this match
294324
case ClassSymbol(cls) => i"$cls"
295325
case Singleton(src, _) => i"$src"
326+
case GenericTuple(arity, _) =>
327+
if arity <= Definitions.MaxTupleArity then i"class Tuple$arity"
328+
else i"trait Tuple { def size: $arity }"
296329

297330
private[Synthesizer] object MirrorSource:
331+
type WithArity = ClassSymbol | GenericTuple
298332

299333
/** Reduces a mirroredType to either its most specific ClassSymbol,
300334
* or a TermRef to a singleton value. These are
@@ -332,6 +366,9 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
332366
reduce(tp.underlying)
333367
case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] =>
334368
Left(i"its subpart `$tp` is not a supported kind (either `*` or `* -> *`)")
369+
case tp @ GenericTupleType(args) =>
370+
val arity = args.size
371+
Right(MirrorSource.GenericTuple(arity, args))
335372
case tp: TypeProxy =>
336373
reduce(tp.underlying)
337374
case tp @ AndType(l, r) =>
@@ -354,10 +391,22 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
354391

355392
private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =
356393

357-
def makeProductMirror(cls: Symbol): TreeWithErrors =
358-
val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
394+
/** widen TermRef to see if they are an alias to an enum singleton */
395+
def isEnumSingletonRef(tp: Type)(using Context): Boolean = tp match
396+
case tp: TermRef =>
397+
val sym = tp.termSymbol
398+
sym.isEnumCase || (!tp.isOverloaded && isEnumSingletonRef(tp.underlying.widenExpr))
399+
case _ => false
400+
401+
def makeProductMirror(msrc: MirrorSource.WithArity): TreeWithErrors =
402+
val mirroredClass = msrc.asClass
403+
val accessors = mirroredClass.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
359404
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
360-
val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
405+
val (arity, nestedPairs) = msrc match
406+
case MirrorSource.GenericTuple(arity, args) =>
407+
(Some(arity), TypeOps.nestedPairs(args))
408+
case MirrorSource.ClassSymbol(cls) =>
409+
(None, TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)))
361410
val (monoType, elemsType) = mirroredType match
362411
case mirroredType: HKTypeLambda =>
363412
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
@@ -367,12 +416,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
367416
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
368417
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
369418
val mirrorType =
370-
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
419+
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, mirroredClass.name, formal)
371420
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
372421
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
373422
val mirrorRef =
374-
if cls.useCompanionAsProductMirror then companionPath(mirroredType, span)
375-
else anonymousMirror(monoType, ExtendsProductMirror, span)
423+
if mirroredClass.useCompanionAsProductMirror then companionPath(mirroredType, span)
424+
else anonymousMirror(monoType, ExtendsProductMirror, arity, span)
376425
withNoErrors(mirrorRef.cast(mirrorType))
377426
end makeProductMirror
378427

@@ -389,9 +438,10 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
389438
else
390439
val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, singleton.name, formal)
391440
withNoErrors(singletonPath.cast(mirrorType))
392-
case MirrorSource.ClassSymbol(cls) =>
393-
if cls.isGenericProduct then makeProductMirror(cls)
394-
else withErrors(i"$cls is not a generic product because ${cls.whyNotGenericProduct}")
441+
case msrc: MirrorSource.WithArity =>
442+
msrc.whyNotGenericProd match
443+
case Some(err) => withErrors(i"${msrc.asClass} is not a generic product because $err")
444+
case _ => makeProductMirror(msrc)
395445
case Left(msg) =>
396446
withErrors(i"type `$mirroredType` is not a generic product because $msg")
397447
end productMirror
@@ -400,7 +450,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
400450

401451
val (acceptableMsg, cls) = MirrorSource.reduce(mirroredType) match
402452
case Right(MirrorSource.Singleton(_, tp)) => (i"its subpart `$tp` is a term reference", NoSymbol)
403-
case Right(MirrorSource.ClassSymbol(cls)) => ("", cls)
453+
case Right(msrc) => ("", msrc.asClass)
404454
case Left(msg) => (msg, NoSymbol)
405455

406456
val clsIsGenericSum = cls.isGenericSum
@@ -457,7 +507,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
457507
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
458508
val mirrorRef =
459509
if cls.useCompanionAsSumMirror then companionPath(mirroredType, span)
460-
else anonymousMirror(monoType, ExtendsSumMirror, span)
510+
else anonymousMirror(monoType, ExtendsSumMirror, tupleArity = None, span)
461511
withNoErrors(mirrorRef.cast(mirrorType))
462512
else if acceptableMsg.nonEmpty then
463513
withErrors(i"type `$mirroredType` is not a generic sum because $acceptableMsg")

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2767,7 +2767,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
27672767
typed(desugar.smallTuple(tree).withSpan(tree.span), pt)
27682768
else {
27692769
val pts =
2770-
if (arity == pt.tupleArity) pt.tupleElementTypes
2770+
if (arity == pt.tupleArity()) pt.tupleElementTypes
27712771
else List.fill(arity)(defn.AnyType)
27722772
val elems = tree.trees.lazyZip(pts).map(
27732773
if ctx.mode.is(Mode.Type) then typedType(_, _, mapPatternBounds = true)

tests/neg/i14127.check

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
-- Error: tests/neg/i14127.scala:6:55 ----------------------------------------------------------------------------------
2+
6 | *: Int *: Int *: Int *: Int *: Int *: EmptyTuple)]] // error
3+
| ^
4+
|No given instance of type deriving.Mirror.Of[(Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int,
5+
| Int
6+
|, Int, Int)] was found for parameter x of method summon in object Predef. Failed to synthesize an instance of type deriving.Mirror.Of[(Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int,
7+
| Int
8+
|, Int, Int)]:
9+
| * class *: is not a generic product because it reduces to tuple with arity 23, expected arity <= 22
10+
| * class *: is not a generic sum because it does not have subclasses

tests/neg/i14127.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import scala.deriving.Mirror
2+
3+
val mT23 = summon[Mirror.Of[(
4+
Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
5+
*: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
6+
*: Int *: Int *: Int *: Int *: Int *: EmptyTuple)]] // error

tests/run/i14127.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import scala.deriving.Mirror
2+
3+
@main def Test =
4+
val mISB = summon[Mirror.Of[Int *: String *: Boolean *: EmptyTuple]]
5+
assert(mISB.fromProduct((1, "foo", true)) == (1, "foo", true))
6+
7+
val mT22 = summon[Mirror.Of[(
8+
Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
9+
*: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
10+
*: Int *: Int *: Int *: Int *: EmptyTuple)]]
11+
12+
// tuple of 22 elements
13+
val t22 = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)
14+
assert(mT22.fromProduct(t22) == t22)

tests/run/i7049.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.deriving._
2+
3+
case class Foo(x: Int, y: String)
4+
5+
def toTuple[T <: Product](x: T)(using m: Mirror.ProductOf[T], mt: Mirror.ProductOf[m.MirroredElemTypes]) =
6+
mt.fromProduct(x)
7+
8+
@main def Test = {
9+
val m = summon[Mirror.ProductOf[Foo]]
10+
val mt1 = summon[Mirror.ProductOf[(Int, String)]]
11+
type R = (Int, String)
12+
val mt2 = summon[Mirror.ProductOf[R]]
13+
val mt3 = summon[Mirror.ProductOf[m.MirroredElemTypes]]
14+
15+
val f = Foo(1, "foo")
16+
val g: (Int, String) = toTuple(f)// (using m, mt1)
17+
assert(g == (1, "foo"))
18+
}

0 commit comments

Comments
 (0)