Skip to content

Commit

Permalink
junit5: override test discovery (#3983)
Browse files Browse the repository at this point in the history
This changes discovery of test classes for Junit5, to be in line with
that of sbt-jupiter-interface.

Closes #3910
  • Loading branch information
jodersky authored Nov 18, 2024
1 parent e5c8778 commit 15e4019
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 0 deletions.
55 changes: 55 additions & 0 deletions scalalib/src/mill/scalalib/TestModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,61 @@ object TestModule {
override def ivyDeps: T[Agg[Dep]] = Task {
super.ivyDeps() ++ Agg(ivy"${mill.scalalib.api.Versions.jupiterInterface}")
}

/**
* Overridden since Junit5 has its own discovery mechanism.
*
* This is basically a re-implementation of sbt's plugin for Junit5 test
* discovery mechanism. See
* https://github.com/sbt/sbt-jupiter-interface/blob/468d4f31f1f6ce8529fff8a8804dd733974c7686/src/plugin/src/main/scala/com/github/sbt/junit/jupiter/sbt/JupiterPlugin.scala#L97C15-L118
* for details.
*
* Note that we access the test discovery via reflection, to avoid mill
* itself having a dependency on Junit5. Hence, if you remove the
* `sbt-jupiter-interface` dependency from `ivyDeps`, make sure to also
* override this method.
*/
override def discoveredTestClasses: T[Seq[String]] = T {
Jvm.inprocess(
runClasspath().map(_.path),
classLoaderOverrideSbtTesting = true,
isolated = true,
closeContextClassLoaderWhenDone = true,
cl => {
val builderClass: Class[_] =
cl.loadClass("com.github.sbt.junit.jupiter.api.JupiterTestCollector$Builder")
val builder = builderClass.getConstructor().newInstance()

builderClass.getMethod("withClassDirectory", classOf[java.io.File]).invoke(
builder,
compile().classes.path.wrapped.toFile
)
builderClass.getMethod("withRuntimeClassPath", classOf[Array[java.net.URL]]).invoke(
builder,
testClasspath().map(_.path.wrapped.toUri().toURL()).toArray
)
builderClass.getMethod("withClassLoader", classOf[ClassLoader]).invoke(builder, cl)

val testCollector = builderClass.getMethod("build").invoke(builder)
val testCollectorClass =
cl.loadClass("com.github.sbt.junit.jupiter.api.JupiterTestCollector")

val result = testCollectorClass.getMethod("collectTests").invoke(testCollector)
val resultClass =
cl.loadClass("com.github.sbt.junit.jupiter.api.JupiterTestCollector$Result")

val items = resultClass.getMethod(
"getDiscoveredTests"
).invoke(result).asInstanceOf[java.util.List[_]]
val itemClass = cl.loadClass("com.github.sbt.junit.jupiter.api.JupiterTestCollector$Item")

import scala.jdk.CollectionConverters._
items.asScala.map { item =>
itemClass.getMethod("getFullyQualifiedClassName").invoke(item).asInstanceOf[String]
}.toSeq
}
)
}
}

/**
Expand Down
6 changes: 6 additions & 0 deletions scalalib/test/resources/junit5/test/src/qux/FooTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package qux;

// this class should not be detected as a test
public class FooTests {

}
17 changes: 17 additions & 0 deletions scalalib/test/resources/junit5/test/src/qux/QuxTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package qux;

import static org.junit.jupiter.api.Assertions.assertTrue;

import org.junit.jupiter.api.Test;

public class QuxTests {

@Test
public void hello() {
assertTrue(true);
}

}

// this class should not be detected as a test
class Dummy{}
30 changes: 30 additions & 0 deletions scalalib/test/src/mill/javalib/junit5/JUnit5Tests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package mill.javalib.junit5

import mill.scalalib.JavaModule
import mill.scalalib.TestModule
import mill.testkit.{TestBaseModule, UnitTester}
import utest._

object JUnit5Tests extends TestSuite {

object module extends TestBaseModule with JavaModule {
object test extends JavaTests with TestModule.Junit5
}

val testModuleSourcesPath = os.Path(sys.env("MILL_TEST_RESOURCE_DIR")) / "junit5"

def tests = Tests {
test("discovery") {
val eval = UnitTester(module, testModuleSourcesPath)
val res = eval(module.test.discoveredTestClasses)
assert(res.isRight)
assert(res.toOption.get.value == Seq("qux.QuxTests"))
}
test("execution") {
val eval = UnitTester(module, testModuleSourcesPath)
val res = eval(module.test.test(""))
assert(res.isRight)
assert(res.toOption.get.value._2.forall(_.fullyQualifiedName == "qux.QuxTests"))
}
}
}

0 comments on commit 15e4019

Please sign in to comment.