Skip to content

Commit 96cec5c

Browse files
committed
SAM Overloading and existential bounds
In some cases existential bounds can be simplified without losing precision. For example: trait Blargle[T] { def compare(a: T, b: T): Int } trait Test { def foo(a: Blargle[_ >: String]): Int } can be simplified to: trait Test { def foo(a: Blargle[String]): Int } see: scala/scala#4101 #SCL8956 Fixed
1 parent 67153b0 commit 96cec5c

File tree

2 files changed

+131
-11
lines changed

2 files changed

+131
-11
lines changed

src/org/jetbrains/plugins/scala/lang/psi/ScalaPsiUtil.scala

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,7 +2341,7 @@ object ScalaPsiUtil {
23412341
* @see SCL-6140
23422342
* @see https://github.com/scala/scala/pull/3018/
23432343
*/
2344-
def toSAMType(expected: ScType, scope: GlobalSearchScope): Option[ScType] = {
2344+
def toSAMType(expected: ScType, scalaScope: GlobalSearchScope): Option[ScType] = {
23452345

23462346
def constructorValidForSAM(constructors: Array[PsiMethod]): Boolean = {
23472347
//primary constructor (if any) must be public, no-args, not overloaded
@@ -2371,8 +2371,14 @@ object ScalaPsiUtil {
23712371
!abst.head.hasTypeParameters
23722372

23732373
if (valid) {
2374-
abst.head.getType() match {
2375-
case Success(tp, _) => Some(sub.subst(tp))
2374+
val fun = abst.head
2375+
fun.getType() match {
2376+
case Success(tp, _) =>
2377+
val subbed = sub.subst(tp)
2378+
extrapolateWildcardBounds(subbed, expected, fun.getProject, scalaScope) match {
2379+
case s@Some(_) => s
2380+
case _ => Some(subbed)
2381+
}
23762382
case _ => None
23772383
}
23782384
} else None
@@ -2384,15 +2390,63 @@ object ScalaPsiUtil {
23842390
//need to generate ScType for Java method
23852391
val method = abst.head
23862392
val project = method.getProject
2387-
val returnType: ScType = ScType.create(method.getReturnType, project, scope)
2393+
val returnType: ScType = ScType.create(method.getReturnType, project, scalaScope)
23882394
val params: Array[ScType] = method.getParameterList.getParameters.map {
2389-
param: PsiParameter => ScType.create(param.getTypeElement.getType, project, scope)
2395+
param: PsiParameter => ScType.create(param.getTypeElement.getType, project, scalaScope)
2396+
}
2397+
val fun = ScFunctionType(returnType, params)(project, scalaScope)
2398+
val subbed = sub.subst(fun)
2399+
extrapolateWildcardBounds(subbed, expected, project, scalaScope) match {
2400+
case s@Some(_) => s
2401+
case _ => Some(subbed)
23902402
}
2391-
val result = ScFunctionType(returnType, params)(project, scope)
2392-
Some(sub.subst(result))
23932403
} else None
23942404
}
23952405
case None => None
23962406
}
23972407
}
2408+
2409+
/**
2410+
* In some cases existential bounds can be simplified without losing precision
2411+
*
2412+
* trait Comparinator[T] { def compare(a: T, b: T): Int }
2413+
*
2414+
* trait Test {
2415+
* def foo(a: Comparinator[_ >: String]): Int
2416+
* }
2417+
*
2418+
* can be simplified to:
2419+
*
2420+
* trait Test {
2421+
* def foo(a: Comparinator[String]): Int
2422+
* }
2423+
*
2424+
* @see https://github.com/scala/scala/pull/4101
2425+
* @see SCL-8956
2426+
*/
2427+
private def extrapolateWildcardBounds(tp: ScType, expected: ScType, proj: Project, scope: GlobalSearchScope): Option[ScType] = {
2428+
expected match {
2429+
case ScExistentialType(ScParameterizedType(expectedDesignator, _), wildcards) =>
2430+
tp match {
2431+
case ScFunctionType(retTp, params) =>
2432+
def convertParameter(tpArg: ScType, variance: Int): ScType = {
2433+
wildcards.find(_.name == tpArg.canonicalText) match {
2434+
case Some(wildcard) =>
2435+
(wildcard.lowerBound, wildcard.upperBound) match {
2436+
case (lo, Any) if variance == ScTypeParam.Contravariant => lo
2437+
case (Nothing, hi) if variance == ScTypeParam.Covariant => hi
2438+
case _ => tpArg
2439+
}
2440+
case _ => tpArg
2441+
}
2442+
}
2443+
//parameter clauses are contravariant positions, return types are covariant positions
2444+
val newParams = params.map(convertParameter(_, ScTypeParam.Contravariant))
2445+
val newRetTp = convertParameter(retTp, ScTypeParam.Covariant)
2446+
Some(ScFunctionType(newRetTp, newParams)(proj, scope))
2447+
case _ => None
2448+
}
2449+
case _ => None
2450+
}
2451+
}
23982452
}

test/org/jetbrains/plugins/scala/annotator/SingleAbstractMethodTest.scala

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,83 @@ class SingleAbstractMethodTest extends ScalaLightPlatformCodeInsightTestCaseAdap
293293
checkCodeHasNoErrors(code)
294294
}
295295

296-
def checkCodeHasNoErrors(code: String) {
297-
assertMatches(messages(code)) {
296+
def testExistentialBounds(): Unit = {
297+
val code =
298+
"""
299+
|trait Blargle[T] {
300+
| def foo(a: T): String
301+
|}
302+
|
303+
|def f(b: Blargle[_ >: Int]) = -1
304+
|f(s => s.toString)
305+
|
306+
|def g[T](b: Blargle[_ >: T]) = -1
307+
|g((s: String) => s)
308+
|
309+
|trait Blergh[T] {
310+
| def foo(): T
311+
|}
312+
|
313+
|def h[T](b: Blergh[_ <: T]) = -1
314+
|h(() => "")
315+
|def i(b: Blergh[_ <: String]) = -1
316+
|i(() => "")
317+
|
318+
""".stripMargin
319+
checkCodeHasNoErrors(code)
320+
}
321+
322+
def testOverload(): Unit = {
323+
val code =
324+
"""
325+
|trait SAMOverload[A] {
326+
| def foo(s: A): Int = ???
327+
|}
328+
|
329+
|def f[T](s: T): Unit = ()
330+
|def f[T](s: T, a: SAMOverload[_ >: T]) = ()
331+
|f("", (s: String) => 2)
332+
|
333+
""".stripMargin
334+
checkCodeHasNoErrors(code)
335+
}
336+
337+
def testJavaSAM(): Unit = {
338+
val scalaCode = "new ObservableCopy(1).mapFunc(x => x + 1)"
339+
val javaCode =
340+
"""
341+
|public interface Func1<T, R> {
342+
| R call(T t);
343+
|}
344+
|
345+
|public class ObservableCopy<T> {
346+
| public ObservableCopy(T t) {}
347+
|
348+
| public final <R> ObservableCopy<R> mapFunc(Func1<? super T, ? extends R> func) {
349+
| return null;
350+
| }
351+
|}
352+
|
353+
""".stripMargin
354+
checkCodeHasNoErrors(scalaCode, Some(javaCode))
355+
}
356+
357+
def checkCodeHasNoErrors(scalaCode: String, javaCode: Option[String] = None) {
358+
assertMatches(messages(scalaCode, javaCode)) {
298359
case Nil =>
299360
}
300361
}
301362

302-
def messages(code: String): List[Message] = {
363+
def messages(@Language("Scala") scalaCode: String, javaCode: Option[String] = None): List[Message] = {
364+
javaCode match {
365+
case Some(s) => configureFromFileTextAdapter("dummy.java", s)
366+
case _ =>
367+
}
368+
303369
val annotator = new ScalaAnnotator() {}
304370
val mock = new AnnotatorHolderMock
305371

306-
val parse: ScalaFile = parseText(code)
372+
val parse: ScalaFile = parseText(scalaCode)
307373

308374
parse.depthFirst.foreach(annotator.annotate(_, mock))
309375

0 commit comments

Comments
 (0)