Skip to content

Commit 6726a78

Browse files
authored
Split ReflectionUtils (#579)
1 parent d2b16df commit 6726a78

File tree

4 files changed

+98
-92
lines changed

4 files changed

+98
-92
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package org.mockito
2+
3+
import org.mockito.invocation.InvocationOnMock
4+
import ru.vyarus.java.generics.resolver.GenericsResolver
5+
6+
import java.lang.reflect.Field
7+
import scala.util.control.NonFatal
8+
9+
/**
10+
* Utility methods for Java reflection operations, particularly for Mockito mocks.
11+
*/
12+
object JavaReflectionUtils {
13+
14+
def resolveWithJavaGenerics(invocation: InvocationOnMock): Option[Class[_]] =
15+
try Some(GenericsResolver.resolve(invocation.getMock.getClass).`type`(invocation.method.getDeclaringClass).method(invocation.method).resolveReturnClass())
16+
catch {
17+
case _: Throwable => None
18+
}
19+
20+
def setFinalStatic(field: Field, newValue: AnyRef): Unit =
21+
try {
22+
// Try to get Unsafe instance (works with both sun.misc.Unsafe and jdk.internal.misc.Unsafe)
23+
val unsafeClass: Class[_] =
24+
try
25+
Class.forName("sun.misc.Unsafe")
26+
catch {
27+
case _: ClassNotFoundException => Class.forName("jdk.internal.misc.Unsafe")
28+
}
29+
30+
val unsafeField = unsafeClass.getDeclaredField("theUnsafe")
31+
unsafeField.setAccessible(true)
32+
val unsafe = unsafeField.get(null)
33+
34+
// Get methods via reflection to handle both Unsafe implementations
35+
val staticFieldBaseMethod = unsafeClass.getMethod("staticFieldBase", classOf[Field])
36+
val staticFieldOffsetMethod = unsafeClass.getMethod("staticFieldOffset", classOf[Field])
37+
val putObjectMethod = unsafeClass.getMethod("putObject", classOf[Object], classOf[Long], classOf[Object])
38+
39+
// Make the field accessible
40+
field.setAccessible(true)
41+
42+
// Get base and offset for the field
43+
val base: Object = staticFieldBaseMethod.invoke(unsafe, field)
44+
val offset: Long = staticFieldOffsetMethod.invoke(unsafe, field).asInstanceOf[Long]
45+
46+
// Set the field value directly
47+
putObjectMethod.invoke(unsafe, base, java.lang.Long.valueOf(offset), newValue)
48+
} catch {
49+
case NonFatal(e) =>
50+
throw new IllegalStateException(s"Cannot modify final field ${field.getName}", e)
51+
}
52+
53+
}

common/src/main/scala/org/mockito/MockitoAPI.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
package org.mockito
1313

1414
import org.mockito.Answers.CALLS_REAL_METHODS
15-
import org.mockito.ReflectionUtils.InvocationOnMockOps
1615
import org.mockito.internal.configuration.plugins.Plugins.getMockMaker
1716
import org.mockito.internal.creation.MockSettingsImpl
1817
import org.mockito.internal.exceptions.Reporter.notAMockPassedToVerifyNoMoreInteractions
@@ -453,7 +452,6 @@ private[mockito] trait DoSomething {
453452
}
454453

455454
private[mockito] trait MockitoEnhancer extends MockCreator {
456-
implicit val invocationOps: InvocationOnMock => InvocationOnMockOps = InvocationOps
457455

458456
/**
459457
* Delegates to <code>Mockito.mock(type: Class[T])</code> It provides a nicer API as you can, for instance, do <code>mock[MyClass]</code> instead of
@@ -630,9 +628,9 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
630628
(settings: MockCreationSettings[O], pt: Prettifier) => ThreadAwareMockHandler(settings, realImpl)(pt)
631629
)
632630

633-
ReflectionUtils.setFinalStatic(moduleField, threadAwareMock)
631+
JavaReflectionUtils.setFinalStatic(moduleField, threadAwareMock)
634632
try block
635-
finally ReflectionUtils.setFinalStatic(moduleField, realImpl)
633+
finally JavaReflectionUtils.setFinalStatic(moduleField, realImpl)
636634
}
637635
}
638636
}
Lines changed: 29 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
package org.mockito
22

3-
import java.lang.reflect.{ Field, Method, Modifier }
4-
5-
import org.mockito.internal.ValueClassWrapper
3+
import org.mockito.JavaReflectionUtils.resolveWithJavaGenerics
64
import org.mockito.invocation.InvocationOnMock
75
import org.scalactic.TripleEquals._
8-
import ru.vyarus.java.generics.resolver.GenericsResolver
96

7+
import java.lang.reflect.Method
108
import scala.reflect.ClassTag
119
import scala.reflect.internal.Symbols
1210
import scala.util.{ Try => uTry }
13-
import scala.util.control.NonFatal
1411

1512
object ReflectionUtils {
1613
import scala.reflect.runtime.{ universe => ru }
@@ -23,58 +20,37 @@ object ReflectionUtils {
2320
def methodToJava(sym: Symbols#MethodSymbol): Method
2421
}]
2522

26-
def listToTuple(l: List[Object]): Any =
27-
l match {
28-
case Nil => Nil
29-
case h :: Nil => h
30-
case _ => Class.forName(s"scala.Tuple${l.size}").getDeclaredConstructors.head.newInstance(l: _*)
31-
}
32-
33-
implicit class InvocationOnMockOps(val invocation: InvocationOnMock) extends AnyVal {
34-
def mock[M]: M = invocation.getMock.asInstanceOf[M]
35-
def method: Method = invocation.getMethod
36-
def arg[A: ValueClassWrapper](index: Int): A = ValueClassWrapper[A].wrapAs[A](invocation.getArgument(index))
37-
def args: List[Any] = invocation.getArguments.toList
38-
def callRealMethod[R](): R = invocation.callRealMethod.asInstanceOf[R]
39-
def argsAsTuple: Any = listToTuple(args.map(_.asInstanceOf[Object]))
40-
41-
def returnType: Class[_] = {
42-
val javaReturnType = method.getReturnType
23+
private[mockito] def returnType(invocation: InvocationOnMock): Class[_] = {
24+
val javaReturnType = invocation.method.getReturnType
4325

44-
if (javaReturnType == classOf[Object])
45-
resolveWithScalaGenerics
46-
.orElse(resolveWithJavaGenerics)
47-
.getOrElse(javaReturnType)
48-
else javaReturnType
49-
}
26+
if (javaReturnType == classOf[Object])
27+
resolveWithScalaGenerics(invocation)
28+
.orElse(resolveWithJavaGenerics(invocation))
29+
.getOrElse(javaReturnType)
30+
else javaReturnType
31+
}
5032

51-
def returnsValueClass: Boolean = findTypeSymbol.exists(_.returnType.typeSymbol.isDerivedValueClass)
33+
private[mockito] def returnsValueClass(invocation: InvocationOnMock): Boolean =
34+
findTypeSymbol(invocation).exists(_.returnType.typeSymbol.isDerivedValueClass)
5235

53-
private def resolveWithScalaGenerics: Option[Class[_]] =
54-
uTry {
55-
findTypeSymbol
56-
.filter(_.returnType.typeSymbol.isClass)
57-
.map(_.asMethod.returnType.typeSymbol.asClass)
58-
.map(mirror.runtimeClass)
59-
}.toOption.flatten
60-
61-
private def findTypeSymbol =
62-
uTry {
63-
mirror
64-
.classSymbol(method.getDeclaringClass)
65-
.info
66-
.decls
67-
.collectFirst {
68-
case symbol if isNonConstructorMethod(symbol) && customMirror.methodToJava(symbol) === method => symbol
69-
}
70-
}.toOption.flatten
36+
private def resolveWithScalaGenerics(invocation: InvocationOnMock): Option[Class[_]] =
37+
uTry {
38+
findTypeSymbol(invocation)
39+
.filter(_.returnType.typeSymbol.isClass)
40+
.map(_.asMethod.returnType.typeSymbol.asClass)
41+
.map(mirror.runtimeClass)
42+
}.toOption.flatten
7143

72-
private def resolveWithJavaGenerics: Option[Class[_]] =
73-
try Some(GenericsResolver.resolve(invocation.getMock.getClass).`type`(method.getDeclaringClass).method(method).resolveReturnClass())
74-
catch {
75-
case _: Throwable => None
76-
}
77-
}
44+
private def findTypeSymbol(invocation: InvocationOnMock) =
45+
uTry {
46+
mirror
47+
.classSymbol(invocation.method.getDeclaringClass)
48+
.info
49+
.decls
50+
.collectFirst {
51+
case symbol if isNonConstructorMethod(symbol) && customMirror.methodToJava(symbol) === invocation.method => symbol
52+
}
53+
}.toOption.flatten
7854

7955
private def isNonConstructorMethod(d: ru.Symbol): Boolean = d.isMethod && !d.isConstructor
8056

@@ -113,37 +89,4 @@ object ReflectionUtils {
11389
.getOrElse(Seq.empty)
11490
}
11591

116-
def setFinalStatic(field: Field, newValue: AnyRef): Unit =
117-
try {
118-
// Try to get Unsafe instance (works with both sun.misc.Unsafe and jdk.internal.misc.Unsafe)
119-
val unsafeClass: Class[_] =
120-
try
121-
Class.forName("sun.misc.Unsafe")
122-
catch {
123-
case _: ClassNotFoundException => Class.forName("jdk.internal.misc.Unsafe")
124-
}
125-
126-
val unsafeField = unsafeClass.getDeclaredField("theUnsafe")
127-
unsafeField.setAccessible(true)
128-
val unsafe = unsafeField.get(null)
129-
130-
// Get methods via reflection to handle both Unsafe implementations
131-
val staticFieldBaseMethod = unsafeClass.getMethod("staticFieldBase", classOf[Field])
132-
val staticFieldOffsetMethod = unsafeClass.getMethod("staticFieldOffset", classOf[Field])
133-
val putObjectMethod = unsafeClass.getMethod("putObject", classOf[Object], classOf[Long], classOf[Object])
134-
135-
// Make the field accessible
136-
field.setAccessible(true)
137-
138-
// Get base and offset for the field
139-
val base: Object = staticFieldBaseMethod.invoke(unsafe, field)
140-
val offset: Long = staticFieldOffsetMethod.invoke(unsafe, field).asInstanceOf[Long]
141-
142-
// Set the field value directly
143-
putObjectMethod.invoke(unsafe, base, java.lang.Long.valueOf(offset), newValue)
144-
} catch {
145-
case NonFatal(e) =>
146-
throw new IllegalStateException(s"Cannot modify final field ${field.getName}", e)
147-
}
148-
14992
}

common/src/main/scala/org/mockito/mockito.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package org
22

33
import java.lang.reflect.Method
44

5-
import org.mockito.ReflectionUtils.InvocationOnMockOps
65
import org.mockito.internal.{ ValueClassExtractor, ValueClassWrapper }
76
import org.mockito.invocation.InvocationOnMock
87
import org.mockito.stubbing.ScalaAnswer
@@ -21,7 +20,20 @@ package object mockito {
2120

2221
def clazz[T](implicit classTag: ClassTag[T]): Class[T] = classTag.runtimeClass.asInstanceOf[Class[T]]
2322

24-
implicit val InvocationOps: InvocationOnMock => InvocationOnMockOps = new InvocationOnMockOps(_)
23+
implicit class InvocationOnMockOps(val invocation: InvocationOnMock) {
24+
def mock[M]: M = invocation.getMock.asInstanceOf[M]
25+
def method: Method = invocation.getMethod
26+
def arg[A: ValueClassWrapper](index: Int): A = ValueClassWrapper[A].wrapAs[A](invocation.getArgument(index))
27+
def args: List[Any] = invocation.getArguments.toList
28+
def callRealMethod[R](): R = invocation.callRealMethod.asInstanceOf[R]
29+
def argsAsTuple: Any = args.map(_.asInstanceOf[Object]) match {
30+
case Nil => Nil
31+
case h :: Nil => h
32+
case l => Class.forName(s"scala.Tuple${l.size}").getDeclaredConstructors.head.newInstance(l: _*)
33+
}
34+
def returnType: Class[_] = ReflectionUtils.returnType(invocation)
35+
def returnsValueClass: Boolean = ReflectionUtils.returnsValueClass(invocation)
36+
}
2537

2638
def invocationToAnswer[T: ValueClassExtractor](f: InvocationOnMock => T): ScalaAnswer[T] =
2739
ScalaAnswer.lift(f.andThen(ValueClassExtractor[T].extractAs[T]))

0 commit comments

Comments
 (0)