Skip to content

Commit 439a17d

Browse files
authored
Support signature polymorphic methods (MethodHandle and VarHandle) (#16225)
fixes #11332
2 parents d7e4f94 + 1b0b830 commit 439a17d

File tree

14 files changed

+193
-8
lines changed

14 files changed

+193
-8
lines changed

compiler/src/dotty/tools/backend/jvm/CoreBTypes.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: DottyBackendInterface]](val bTyp
134134

135135
private lazy val jliCallSiteRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.CallSite])
136136
private lazy val jliLambdaMetafactoryRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.LambdaMetafactory])
137-
private lazy val jliMethodHandleRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodHandle])
138-
private lazy val jliMethodHandlesLookupRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodHandles.Lookup])
137+
private lazy val jliMethodHandleRef : ClassBType = classBTypeFromSymbol(defn.MethodHandleClass)
138+
private lazy val jliMethodHandlesLookupRef : ClassBType = classBTypeFromSymbol(defn.MethodHandlesLookupClass)
139139
private lazy val jliMethodTypeRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodType])
140140
private lazy val jliStringConcatFactoryRef : ClassBType = classBTypeFromSymbol(requiredClass("java.lang.invoke.StringConcatFactory")) // since JDK 9
141141
private lazy val srLambdaDeserialize : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.LambdaDeserialize])

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,10 @@ class Definitions {
734734
}
735735
def JavaEnumType = JavaEnumClass.typeRef
736736

737+
@tu lazy val MethodHandleClass: ClassSymbol = requiredClass("java.lang.invoke.MethodHandle")
738+
@tu lazy val MethodHandlesLookupClass: ClassSymbol = requiredClass("java.lang.invoke.MethodHandles.Lookup")
739+
@tu lazy val VarHandleClass: ClassSymbol = requiredClass("java.lang.invoke.VarHandle")
740+
737741
@tu lazy val StringBuilderClass: ClassSymbol = requiredClass("scala.collection.mutable.StringBuilder")
738742
@tu lazy val MatchErrorClass : ClassSymbol = requiredClass("scala.MatchError")
739743
@tu lazy val ConversionClass : ClassSymbol = requiredClass("scala.Conversion").typeRef.symbol.asClass

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,26 @@ object SymDenotations {
960960

961961
def isSkolem: Boolean = name == nme.SKOLEM
962962

963+
// Java language spec: https://docs.oracle.com/javase/specs/jls/se11/html/jls-15.html#jls-15.12.3
964+
// Scala 2 spec: https://scala-lang.org/files/archive/spec/2.13/06-expressions.html#signature-polymorphic-methods
965+
def isSignaturePolymorphic(using Context): Boolean =
966+
containsSignaturePolymorphic
967+
&& is(JavaDefined)
968+
&& hasAnnotation(defn.NativeAnnot)
969+
&& atPhase(typerPhase)(symbol.denot).paramSymss.match
970+
case List(List(p)) => p.info.isRepeatedParam
971+
case _ => false
972+
973+
def containsSignaturePolymorphic(using Context): Boolean =
974+
maybeOwner == defn.MethodHandleClass
975+
|| maybeOwner == defn.VarHandleClass
976+
977+
def originalSignaturePolymorphic(using Context): Denotation =
978+
if containsSignaturePolymorphic && !isSignaturePolymorphic then
979+
val d = owner.info.member(name)
980+
if d.symbol.isSignaturePolymorphic then d else NoDenotation
981+
else NoDenotation
982+
963983
def isInlineMethod(using Context): Boolean =
964984
isAllOf(InlineMethod, butNot = Accessor)
965985

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,13 @@ class TreePickler(pickler: TastyPickler) {
426426
writeByte(THROW)
427427
pickleTree(args.head)
428428
}
429+
else if fun.symbol.originalSignaturePolymorphic.exists then
430+
writeByte(APPLYsigpoly)
431+
withLength {
432+
pickleTree(fun)
433+
pickleType(fun.tpe.widenTermRefExpr, richTypes = true) // this widens to a MethodType, so need richTypes
434+
args.foreach(pickleTree)
435+
}
429436
else {
430437
writeByte(APPLY)
431438
withLength {

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,12 @@ class TreeUnpickler(reader: TastyReader,
12361236
else tpd.Apply(fn, args)
12371237
case TYPEAPPLY =>
12381238
tpd.TypeApply(readTerm(), until(end)(readTpt()))
1239+
case APPLYsigpoly =>
1240+
val fn = readTerm()
1241+
val methType = readType()
1242+
val args = until(end)(readTerm())
1243+
val fun2 = typer.Applications.retypeSignaturePolymorphicFn(fn, methType)
1244+
tpd.Apply(fun2, args)
12391245
case TYPED =>
12401246
val expr = readTerm()
12411247
val tpt = readTpt()

compiler/src/dotty/tools/dotc/semanticdb/SemanticSymbolBuilder.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ class SemanticSymbolBuilder:
7474
def addOwner(owner: Symbol): Unit =
7575
if !owner.isRoot then addSymName(b, owner)
7676

77-
def addOverloadIdx(sym: Symbol): Unit =
77+
def addOverloadIdx(initSym: Symbol): Unit =
78+
// revert from the compiler-generated overload of the signature polymorphic method
79+
val sym = initSym.originalSignaturePolymorphic.symbol.orElse(initSym)
7880
val decls =
7981
val decls0 = sym.owner.info.decls.lookupAll(sym.name)
8082
if sym.owner.isAllOf(JavaModule) then

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,9 @@ abstract class Recheck extends Phase, SymTransformer:
261261
mt.instantiate(argTypes)
262262

263263
def recheckApply(tree: Apply, pt: Type)(using Context): Type =
264-
val funtpe = recheck(tree.fun)
264+
val funTp = recheck(tree.fun)
265+
// reuse the tree's type on signature polymorphic methods, instead of using the (wrong) rechecked one
266+
val funtpe = if tree.fun.symbol.originalSignaturePolymorphic.exists then tree.fun.tpe else funTp
265267
funtpe.widen match
266268
case fntpe: MethodType =>
267269
assert(fntpe.paramInfos.hasSameLengthAs(tree.args))

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import Inferencing._
2323
import reporting._
2424
import transform.TypeUtils._
2525
import transform.SymUtils._
26-
import Nullables._
26+
import Nullables._, NullOpsDecorator.*
2727
import config.Feature
2828

2929
import collection.mutable
@@ -340,6 +340,12 @@ object Applications {
340340
val getter = findDefaultGetter(fn, n, testOnly)
341341
if getter.isEmpty then getter
342342
else spliceMeth(getter.withSpan(fn.span), fn)
343+
344+
def retypeSignaturePolymorphicFn(fun: Tree, methType: Type)(using Context): Tree =
345+
val sym1 = fun.symbol
346+
val flags2 = sym1.flags | NonMember // ensures Select typing doesn't let TermRef#withPrefix revert the type
347+
val sym2 = sym1.copy(info = methType, flags = flags2) // symbol not entered, to avoid overload resolution problems
348+
fun.withType(sym2.termRef)
343349
}
344350

345351
trait Applications extends Compatibility {
@@ -936,6 +942,21 @@ trait Applications extends Compatibility {
936942
/** Type application where arguments come from prototype, and no implicits are inserted */
937943
def simpleApply(fun1: Tree, proto: FunProto)(using Context): Tree =
938944
methPart(fun1).tpe match {
945+
case funRef: TermRef if funRef.symbol.isSignaturePolymorphic =>
946+
// synthesize a method type based on the types at the call site.
947+
// one can imagine the original signature-polymorphic method as
948+
// being infinitely overloaded, with each individual overload only
949+
// being brought into existence as needed
950+
val originalResultType = funRef.symbol.info.resultType.stripNull
951+
val resultType =
952+
if !originalResultType.isRef(defn.ObjectClass) then originalResultType
953+
else AvoidWildcardsMap()(proto.resultType.deepenProtoTrans) match
954+
case SelectionProto(nme.asInstanceOf_, PolyProto(_, resTp), _, _) => resTp
955+
case resTp if isFullyDefined(resTp, ForceDegree.all) => resTp
956+
case _ => defn.ObjectType
957+
val methType = MethodType(proto.typedArgs().map(_.tpe.widen), resultType)
958+
val fun2 = Applications.retypeSignaturePolymorphicFn(fun1, methType)
959+
simpleApply(fun2, proto)
939960
case funRef: TermRef =>
940961
val app = ApplyTo(tree, fun1, funRef, proto, pt)
941962
convertNewGenericArray(

project/Build.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1823,9 +1823,10 @@ object Build {
18231823
settings(disableDocSetting).
18241824
settings(
18251825
versionScheme := Some("semver-spec"),
1826-
if (mode == Bootstrapped) {
1827-
commonMiMaSettings
1828-
} else {
1826+
if (mode == Bootstrapped) Def.settings(
1827+
commonMiMaSettings,
1828+
mimaBinaryIssueFilters ++= MiMaFilters.TastyCore,
1829+
) else {
18291830
Nil
18301831
}
18311832
)

project/MiMaFilters.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@ object MiMaFilters {
2424
ProblemFilters.exclude[MissingClassProblem]("scala.caps$Pure"),
2525
ProblemFilters.exclude[MissingClassProblem]("scala.caps$unsafe$"),
2626
)
27+
val TastyCore: Seq[ProblemFilter] = Seq(
28+
ProblemFilters.exclude[MissingMethodProblem]("dotty.tools.tasty.TastyFormat.APPLYsigpoly"),
29+
)
2730
}

tasty/src/dotty/tools/tasty/TastyFormat.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ Standard-Section: "ASTs" TopLevelStat*
9191
THROW throwableExpr_Term -- throw throwableExpr
9292
NAMEDARG paramName_NameRef arg_Term -- paramName = arg
9393
APPLY Length fn_Term arg_Term* -- fn(args)
94+
APPLYsigpoly Length fn_Term meth_Type arg_Term* -- The application of a signature-polymorphic method
9495
TYPEAPPLY Length fn_Term arg_Type* -- fn[args]
9596
SUPER Length this_Term mixinTypeIdent_Tree? -- super[mixin]
9697
TYPED Length expr_Term ascriptionType_Term -- expr: ascription
@@ -578,6 +579,7 @@ object TastyFormat {
578579
// final val ??? = 178
579580
// final val ??? = 179
580581
final val METHODtype = 180
582+
final val APPLYsigpoly = 181
581583

582584
final val MATCHtype = 190
583585
final val MATCHtpt = 191
@@ -744,6 +746,7 @@ object TastyFormat {
744746
case BOUNDED => "BOUNDED"
745747
case APPLY => "APPLY"
746748
case TYPEAPPLY => "TYPEAPPLY"
749+
case APPLYsigpoly => "APPLYsigpoly"
747750
case NEW => "NEW"
748751
case THROW => "THROW"
749752
case TYPED => "TYPED"

tests/explicit-nulls/run/i11332.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// scalajs: --skip
2+
import scala.language.unsafeNulls
3+
4+
import java.lang.invoke._, MethodType.methodType
5+
6+
// A copy of tests/run/i11332.scala
7+
// to test the bootstrap minimisation which failed
8+
// (because bootstrap runs under explicit nulls)
9+
class Foo:
10+
def neg(x: Int): Int = -x
11+
12+
object Test:
13+
def main(args: Array[String]): Unit =
14+
val l = MethodHandles.lookup()
15+
val self = new Foo()
16+
17+
val res4 = {
18+
l // explicit chain method call - previously derivedSelect broke the type
19+
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
20+
.invokeExact(self, 4): Int
21+
}
22+
assert(-4 == res4)

tests/run/i11332.scala

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// scalajs: --skip
2+
import scala.language.unsafeNulls
3+
4+
import java.lang.invoke._, MethodType.methodType
5+
6+
class Foo:
7+
def neg(x: Int): Int = -x
8+
def rev(s: String): String = s.reverse
9+
def over(l: Long): String = "long"
10+
def over(i: Int): String = "int"
11+
def unit(s: String): Unit = ()
12+
def obj(s: String): Object = s
13+
14+
object Test:
15+
def main(args: Array[String]): Unit =
16+
val l = MethodHandles.lookup()
17+
val self = new Foo()
18+
val mhNeg = l.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
19+
val mhRev = l.findVirtual(classOf[Foo], "rev", methodType(classOf[String], classOf[String]))
20+
val mhOverL = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Long]))
21+
val mhOverI = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Int]))
22+
val mhUnit = l.findVirtual(classOf[Foo], "unit", methodType(classOf[Unit], classOf[String]))
23+
val mhObj = l.findVirtual(classOf[Foo], "obj", methodType(classOf[Any], classOf[String]))
24+
val mhCL = l.findStatic(classOf[ClassLoader], "getPlatformClassLoader", methodType(classOf[ClassLoader]))
25+
26+
assert(-42 == (mhNeg.invokeExact(self, 42): Int))
27+
assert(-33 == (mhNeg.invokeExact(self, 33): Int))
28+
29+
assert("oof" == (mhRev.invokeExact(self, "foo"): String))
30+
assert("rab" == (mhRev.invokeExact(self, "bar"): String))
31+
32+
assert("long" == (mhOverL.invokeExact(self, 1L): String))
33+
assert("int" == (mhOverI.invokeExact(self, 1): String))
34+
35+
assert(-3 == (id(mhNeg.invokeExact(self, 3)): Int))
36+
expectWrongMethod(mhNeg.invokeExact(self, 4))
37+
38+
{ mhUnit.invokeExact(self, "hi"): Unit; () } // explicit block
39+
val hi2: Unit = mhUnit.invokeExact(self, "hi2")
40+
assert((()) == hi2)
41+
def hi3: Unit = mhUnit.invokeExact(self, "hi3")
42+
assert((()) == hi3)
43+
44+
{ mhObj.invokeExact(self, "any"); () } // explicit block
45+
val any2 = mhObj.invokeExact(self, "any2")
46+
assert("any2" == any2)
47+
def any3 = mhObj.invokeExact(self, "any3")
48+
assert("any3" == any3)
49+
50+
assert(null != (mhCL.invoke(): ClassLoader))
51+
assert(null != (mhCL.invoke().asInstanceOf[ClassLoader]: ClassLoader))
52+
assert(null != (mhCL.invokeExact(): ClassLoader))
53+
assert(null != (mhCL.invokeExact().asInstanceOf[ClassLoader]: ClassLoader))
54+
55+
expectWrongMethod {
56+
l // explicit chain method call
57+
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
58+
.invokeExact(self, 3)
59+
}
60+
val res4 = {
61+
l // explicit chain method call
62+
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
63+
.invokeExact(self, 4): Int
64+
}
65+
assert(-4 == res4)
66+
67+
def id[T](x: T): T = x
68+
69+
def expectWrongMethod(op: => Any) = try {
70+
op
71+
throw new AssertionError("expected operation to fail but it didn't")
72+
} catch case expected: WrongMethodTypeException => ()

tests/run/t12348.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// test: -jvm 11+
2+
// scalajs: --skip
3+
import java.lang.invoke._
4+
import scala.runtime.IntRef
5+
6+
object Test {
7+
def main(args: Array[String]): Unit = {
8+
val ref = new scala.runtime.IntRef(0)
9+
val varHandle = MethodHandles.lookup()
10+
.in(classOf[IntRef])
11+
.findVarHandle(classOf[IntRef], "elem", classOf[Int])
12+
assert(0 == (varHandle.getAndSet(ref, 1): Int))
13+
assert(1 == (varHandle.getAndSet(ref, 2): Int))
14+
assert(2 == ref.elem)
15+
16+
assert((()) == (varHandle.set(ref, 3): Any))
17+
assert(3 == (varHandle.get(ref): Int))
18+
19+
assert(true == (varHandle.compareAndSet(ref, 3, 4): Any))
20+
assert(4 == (varHandle.get(ref): Int))
21+
}
22+
}

0 commit comments

Comments
 (0)