diff --git a/junit-vintage-engine/src/main/java/org/junit/vintage/engine/VintageTestEngine.java b/junit-vintage-engine/src/main/java/org/junit/vintage/engine/VintageTestEngine.java index be13d89580a7..56181b8b3b1b 100644 --- a/junit-vintage-engine/src/main/java/org/junit/vintage/engine/VintageTestEngine.java +++ b/junit-vintage-engine/src/main/java/org/junit/vintage/engine/VintageTestEngine.java @@ -15,9 +15,17 @@ import static org.junit.vintage.engine.descriptor.VintageTestDescriptor.ENGINE_ID; import java.util.Iterator; +import java.util.List; import java.util.Optional; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.apiguardian.api.API; +import org.junit.platform.commons.logging.Logger; +import org.junit.platform.commons.logging.LoggerFactory; import org.junit.platform.engine.EngineDiscoveryRequest; import org.junit.platform.engine.EngineExecutionListener; import org.junit.platform.engine.ExecutionRequest; @@ -37,6 +45,11 @@ @API(status = INTERNAL, since = "4.12") public final class VintageTestEngine implements TestEngine { + private static final Logger logger = LoggerFactory.getLogger(VintageTestEngine.class); + + private static final int DEFAULT_THREAD_POOL_SIZE = 10; + private static final int SHUTDOWN_TIMEOUT_SECONDS = 60; + @Override public String getId() { return ENGINE_ID; @@ -75,11 +88,71 @@ public void execute(ExecutionRequest request) { private void executeAllChildren(VintageEngineDescriptor engineDescriptor, EngineExecutionListener engineExecutionListener) { + boolean parallelExecutionEnabled = getParallelExecutionEnabled(); + + if (parallelExecutionEnabled) { + executeInParallel(engineDescriptor, engineExecutionListener); + } + else { + executeSequentially(engineDescriptor, engineExecutionListener); + } + } + + private void executeInParallel(VintageEngineDescriptor engineDescriptor, + EngineExecutionListener engineExecutionListener) { + int taskCount = engineDescriptor.getChildren().size(); + CountDownLatch latch = new CountDownLatch(taskCount); + + ExecutorService executorService = Executors.newWorkStealingPool(getThreadPoolSize()); + RunnerExecutor runnerExecutor = new RunnerExecutor(engineExecutionListener); + + List children = new CopyOnWriteArrayList<>(engineDescriptor.getModifiableChildren()); + for (TestDescriptor descriptor : children) { + RunnerTestDescriptor testDescriptor = (RunnerTestDescriptor) descriptor; + executorService.submit(() -> { + try { + runnerExecutor.execute(testDescriptor); + } + catch (Exception e) { + engineExecutionListener.executionSkipped(testDescriptor, e.getMessage()); + } + finally { + latch.countDown(); + } + }); + } + try { + if (!latch.await(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + logger.warn(() -> "Timeout while waiting for parallel test execution to finish"); + } + } + catch (InterruptedException e) { + logger.warn(e, () -> "Interrupted while awaiting executor service termination"); + Thread.currentThread().interrupt(); + } + finally { + executorService.shutdown(); + } + } + + private void executeSequentially(VintageEngineDescriptor engineDescriptor, + EngineExecutionListener engineExecutionListener) { RunnerExecutor runnerExecutor = new RunnerExecutor(engineExecutionListener); - for (Iterator iterator = engineDescriptor.getModifiableChildren().iterator(); iterator.hasNext();) { + Iterator iterator = engineDescriptor.getModifiableChildren().iterator(); + while (iterator.hasNext()) { runnerExecutor.execute((RunnerTestDescriptor) iterator.next()); iterator.remove(); } } + private boolean getParallelExecutionEnabled() { + // get parallel execution enabled from configuration + return true; + } + + private int getThreadPoolSize() { + // get thread pool size from configuration + return DEFAULT_THREAD_POOL_SIZE; + } + }