Skip to content

Commit b46a29f

Browse files
committed
Refactor SAM method compatibility checks using similar logic in Erasure.Boxing.adaptClosure
1 parent 829865b commit b46a29f

File tree

4 files changed

+233
-42
lines changed

4 files changed

+233
-42
lines changed

compiler/src/dotty/tools/dotc/config/JavaPlatform.scala

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,6 @@ class JavaPlatform extends Platform {
4545

4646
def rootLoader(root: TermSymbol)(using Context): SymbolLoader = new SymbolLoaders.PackageLoader(root, classPath)
4747

48-
private def samMethodHasCompatibleBridge(cls: ClassSymbol)(using Context): Boolean =
49-
cls.typeRef.possibleSamMethods match
50-
case Seq(samMeth) =>
51-
val samResultType = samMeth.info.resultType
52-
if samResultType.isRef(defn.UnitClass) then
53-
// If the result type of the SAM method is Unit, but the result type of the overridden
54-
// methods is not Unit, the bridge will return Object, which is not compatible with Void
55-
// required by LambdaMetaFactory.
56-
// See issue #24573 for details.
57-
samMeth.symbol.allOverriddenSymbols.forall(_.info.resultType.isRef(defn.UnitClass))
58-
else true
59-
case _ => false
60-
6148
/** Is the SAMType `cls` also a SAM under the rules of the JVM? */
6249
def isSam(cls: ClassSymbol)(using Context): Boolean =
6350
cls.isAllOf(NoInitsTrait) &&
@@ -66,8 +53,8 @@ class JavaPlatform extends Platform {
6653
!ExplicitOuter.needsOuterIfReferenced(cls) &&
6754
// Superaccessors already show up as abstract methods here, so no test necessary
6855
cls.typeRef.fields.isEmpty &&
69-
// Check if SAM method will have a compatible bridge for LambdaMetaFactory
70-
samMethodHasCompatibleBridge(cls)
56+
// Check if the SAM can be implemented via LambdaMetaFactory
57+
!TypeErasure.samNeedsExpansion(cls)
7158

7259
/** We could get away with excluding BoxedBooleanClass for the
7360
* purpose of equality testing since it need not compare equal

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end SourceLanguage
7474
* only for isInstanceOf, asInstanceOf: PolyType, TypeParamRef, TypeBounds
7575
*
7676
*/
77-
object TypeErasure {
77+
object TypeErasure:
7878

7979
private def erasureDependsOnArgs(sym: Symbol)(using Context) =
8080
sym == defn.ArrayClass || sym == defn.PairClass || sym.isDerivedValueClass
@@ -586,7 +586,79 @@ object TypeErasure {
586586
defn.FunctionType(n = info.nonErasedParamCount)
587587
}
588588
erasure(functionType(applyInfo))
589-
}
589+
590+
/** Check if LambdaMetaFactory cannot handle the SAM method's required signature adaptation.
591+
*
592+
* When a SAM method overrides other methods, the erased signatures must be compatible
593+
* for LambdaMetaFactory to work. This method returns true if any overridden method
594+
* has an incompatible erased signature that LMF cannot auto-adapt.
595+
*
596+
* The adaptation rules mirror those in `Erasure.Boxing.adaptClosure`:
597+
* - For parameters: primitives and value classes cannot be auto-adapted by LMF
598+
* - For results: value classes and Unit cannot be auto-adapted by LMF
599+
*
600+
* When this returns true, the SAM class must be expanded rather than using LMF.
601+
*
602+
* @param cls The SAM class to check
603+
* @return true if LMF cannot handle the required adaptation
604+
*/
605+
def samNeedsExpansion(cls: ClassSymbol)(using Context): Boolean =
606+
val Seq(samMeth) = cls.typeRef.possibleSamMethods
607+
val samMethSym = samMeth.symbol
608+
val erasedSamInfo = transformInfo(samMethSym, samMeth.info)
609+
// println(i"Checking whether SAM ${cls} needs expansion, erased SAM info: $erasedSamInfo")
610+
611+
val (erasedSamParamTypes, erasedSamResultType) = erasedSamInfo match
612+
case mt: MethodType => (mt.paramInfos, mt.resultType)
613+
case _ => return false
614+
615+
def sameClass(tp1: Type, tp2: Type) = tp1.classSymbol == tp2.classSymbol
616+
617+
/** Can the implementation parameter type `tp` be auto-adapted to a different
618+
* parameter type in the SAM?
619+
*
620+
* For derived value classes, we always need to do the bridging manually.
621+
* For primitives, we cannot rely on auto-adaptation on the JVM because
622+
* the Scala spec requires null to be "unboxed" to the default value of
623+
* the value class, but the adaptation performed by LambdaMetaFactory
624+
* will throw a `NullPointerException` instead.
625+
*/
626+
def autoAdaptedParam(tp: Type) =
627+
!tp.isErasedValueType && !tp.isPrimitiveValueType
628+
629+
/** Can the implementation result type be auto-adapted to a different result
630+
* type in the SAM?
631+
*
632+
* For derived value classes, it's the same story as for parameters.
633+
* For non-Unit primitives, we can actually rely on the `LambdaMetaFactory`
634+
* adaptation, because it only needs to box, not unbox, so no special
635+
* handling of null is required.
636+
*/
637+
def autoAdaptedResult(implResultType: Type) =
638+
!implResultType.isErasedValueType && !(implResultType.classSymbol eq defn.UnitClass)
639+
640+
samMethSym.allOverriddenSymbols.exists { overridden =>
641+
val erasedOverriddenInfo = transformInfo(overridden, overridden.info)
642+
// println(i" comparing to overridden method ${overridden} with erased info: $erasedOverriddenInfo")
643+
erasedOverriddenInfo match
644+
case mt: MethodType =>
645+
val overriddenParamTypes = mt.paramInfos
646+
val overriddenResultType = mt.resultType
647+
648+
val paramAdaptationNeeded =
649+
erasedSamParamTypes.lazyZip(overriddenParamTypes).exists((samType, overriddenType) =>
650+
!sameClass(samType, overriddenType) && (!autoAdaptedParam(samType)
651+
// LambdaMetaFactory cannot auto-adapt between Object and Array types
652+
|| overriddenType.isInstanceOf[JavaArrayType]))
653+
654+
val resultAdaptationNeeded =
655+
!sameClass(erasedSamResultType, overriddenResultType) && !autoAdaptedResult(erasedSamResultType)
656+
657+
paramAdaptationNeeded || resultAdaptationNeeded
658+
case _ => false
659+
}
660+
end samNeedsExpansion
661+
end TypeErasure
590662

591663
import TypeErasure.*
592664

tests/run/i24573.check

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,35 @@
11
1
22
2
33
3
4+
11
5+
12
6+
13
7+
14
8+
15
9+
16
10+
17
11+
18
12+
19
13+
20
14+
21
15+
22
16+
23
17+
24
18+
31
19+
32
20+
41
421
42
5-
hello
6-
world
7-
!!
22+
43
23+
44
24+
45
25+
46
26+
51
27+
52
28+
53
29+
55
30+
56
31+
57
32+
61
33+
62
34+
63
35+
64

tests/run/i24573.scala

Lines changed: 126 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,147 @@
1-
trait Con[-T] extends (T => Unit):
1+
trait ConTU[-T] extends (T => Unit):
22
def apply(t: T): Unit
33

4-
trait Con2[-T] extends (T => Int):
4+
trait ConTI[-T] extends (T => Int):
55
def apply(t: T): Int
66

7-
trait Con3[+R] extends (() => R):
7+
trait ConTS[-T] extends (T => String):
8+
def apply(t: T): String
9+
10+
trait ConIR[+R] extends (Int => R):
11+
def apply(t: Int): R
12+
13+
trait ConSR[+R] extends (String => R):
14+
def apply(t: String): R
15+
16+
trait ConUR[+R] extends (() => R):
817
def apply(): R
918

19+
trait ConII extends (Int => Int):
20+
def apply(t: Int): Int
21+
22+
trait ConSI extends (String => Int):
23+
def apply(t: String): Int
24+
25+
trait ConIS extends (Int => String):
26+
def apply(t: Int): String
27+
28+
trait ConUU extends (() => Unit):
29+
def apply(): Unit
30+
1031
trait F1[-T, +R]:
1132
def apply(t: T): R
1233

13-
trait SF[-T] extends F1[T, Unit]:
34+
trait SFTU[-T] extends F1[T, Unit]:
1435
def apply(t: T): Unit
1536

16-
trait F1U[-T]:
17-
def apply(t: T): Unit
37+
trait SFTI[-T] extends F1[T, Int]:
38+
def apply(t: T): Int
39+
40+
trait SFTS[-T] extends F1[T, String]:
41+
def apply(t: T): String
42+
43+
trait SFIR [+R] extends F1[Int, R]:
44+
def apply(t: Int): R
45+
46+
trait SFSR [+R] extends F1[String, R]:
47+
def apply(t: String): R
48+
49+
trait SFII extends F1[Int, Int]:
50+
def apply(t: Int): Int
51+
52+
trait SFSI extends F1[String, Int]:
53+
def apply(t: String): Int
54+
55+
trait SFIS extends F1[Int, String]:
56+
def apply(t: Int): String
57+
58+
trait SFIU extends F1[Int, Unit]:
59+
def apply(t: Int): Unit
1860

19-
trait SF2 extends F1U[String]:
20-
def apply(t: String): Unit
2161

2262
object Test:
2363
def main(args: Array[String]): Unit =
24-
val f1: (Int => Unit) = i => println(i)
25-
f1(1)
64+
val fIU: (Int => Unit) = (x: Int) => println(x)
65+
fIU(1)
66+
67+
val fIS: (Int => String) = (x: Int) => x.toString
68+
println(fIS(2))
69+
70+
val fUI: (() => Int) = () => 3
71+
println(fUI())
72+
73+
val conITU: ConTU[Int] = (x: Int) => println(x)
74+
conITU(11)
75+
val conITI: ConTI[Int] = (x: Int) => x
76+
println(conITI(12))
77+
val conITS: ConTS[Int] = (x: Int) => x.toString
78+
println(conITS(13))
79+
val conSTS: ConTS[String] = (x: String) => x
80+
println(conSTS("14"))
81+
82+
val conIRS: ConIR[String] = (x: Int) => x.toString
83+
println(conIRS(15))
84+
val conIRI: ConIR[Int] = (x: Int) => x
85+
println(conIRI(16))
86+
val conIRU: ConIR[Unit] = (x: Int) => println(x)
87+
conIRU(17)
88+
89+
val conSRI: ConSR[Int] = (x: String) => x.toInt
90+
println(conSRI("18"))
91+
val conURI: ConUR[Int] = () => 19
92+
println(conURI())
93+
val conURU: ConUR[Unit] = () => println("20")
94+
conURU()
95+
96+
val conII: ConII = (x: Int) => x
97+
println(conII(21))
98+
val conSI: ConSI = (x: String) => x.toInt
99+
println(conSI("22"))
100+
val conIS: ConIS = (x: Int) => x.toString
101+
println(conIS(23))
102+
val conUU: ConUU = () => println("24")
103+
conUU()
104+
105+
val ffIU: F1[Int, Unit] = (x: Int) => println(x)
106+
ffIU(31)
107+
val ffIS: F1[Int, String] = (x: Int) => x.toString
108+
println(ffIS(32))
109+
110+
val sfITU: SFTU[Int] = (x: Int) => println(x)
111+
sfITU(41)
112+
val sfSTU: SFTU[String] = (x: String) => println(x)
113+
sfSTU("42")
26114

27-
val c1: Con[Int] = i => println(i)
28-
c1(2)
115+
val sfITI: SFTI[Int] = (x: Int) => x
116+
println(sfITI(43))
117+
val sfSTI: SFTI[String] = (x: String) => x.toInt
118+
println(sfSTI("44"))
29119

30-
val c2: Con2[Int] = i => { println(i); i }
31-
c2(3)
120+
val sfITS: SFTS[Int] = (x: Int) => x.toString
121+
println(sfITS(45))
122+
val sfSTS: SFTS[String] = (x: String) => x
123+
println(sfSTS("46"))
32124

33-
val c3: Con3[Int] = () => 42
34-
println(c3())
125+
val sfIRI: SFIR[Int] = (x: Int) => x
126+
println(sfIRI(51))
127+
val sfIRS: SFIR[String] = (x: Int) => x.toString
128+
println(sfIRS(52))
129+
val sfIRU: SFIR[Unit] = (x: Int) => println(x)
130+
sfIRU(53)
35131

36-
val c4: Con3[Unit] = () => println("hello")
37-
c4()
132+
val sfSRI: SFSR[Int] = (x: String) => x.toInt
133+
println(sfSRI("55"))
134+
val sfSRS: SFSR[String] = (x: String) => x
135+
println(sfSRS("56"))
136+
val sfSRU: SFSR[Unit] = (x: String) => println(x)
137+
sfSRU("57")
38138

39-
val f5: SF[String] = s => println(s)
40-
f5("world")
139+
val sfII: SFII = (x: Int) => x
140+
println(sfII(61))
141+
val sfSI: SFSI = (x: String) => x.toInt
142+
println(sfSI("62"))
143+
val sfIS: SFIS = (x: Int) => x.toString
144+
println(sfIS(63))
145+
val sfIU: SFIU = (x: Int) => println(x)
146+
sfIU(64)
41147

42-
val f6: SF2 = i => println(i)
43-
f6("!!")

0 commit comments

Comments
 (0)