Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support some level of parallelization in junit-vintage-engine #4135

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,20 @@
import static org.junit.platform.engine.TestExecutionResult.successful;
import static org.junit.vintage.engine.descriptor.VintageTestDescriptor.ENGINE_ID;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.apiguardian.api.API;
import org.junit.platform.commons.logging.Logger;
import org.junit.platform.commons.logging.LoggerFactory;
import org.junit.platform.commons.util.ExceptionUtils;
import org.junit.platform.engine.EngineDiscoveryRequest;
import org.junit.platform.engine.EngineExecutionListener;
import org.junit.platform.engine.ExecutionRequest;
Expand All @@ -37,6 +47,13 @@
@API(status = INTERNAL, since = "4.12")
public final class VintageTestEngine implements TestEngine {

private static final Logger logger = LoggerFactory.getLogger(VintageTestEngine.class);

private static final int DEFAULT_THREAD_POOL_SIZE = Runtime.getRuntime().availableProcessors();
private static final int SHUTDOWN_TIMEOUT_SECONDS = 30;
private static final String PARALLEL_EXECUTION_ENABLED = "junit.vintage.execution.parallel.enabled";
private static final String PARALLEL_POOL_SIZE = "junit.vintage.execution.parallel.pool-size";

@Override
public String getId() {
return ENGINE_ID;
Expand Down Expand Up @@ -69,11 +86,73 @@ public void execute(ExecutionRequest request) {
EngineExecutionListener engineExecutionListener = request.getEngineExecutionListener();
VintageEngineDescriptor engineDescriptor = (VintageEngineDescriptor) request.getRootTestDescriptor();
engineExecutionListener.executionStarted(engineDescriptor);
executeAllChildren(engineDescriptor, engineExecutionListener);
executeAllChildren(engineDescriptor, engineExecutionListener, request);
engineExecutionListener.executionFinished(engineDescriptor, successful());
}

private void executeAllChildren(VintageEngineDescriptor engineDescriptor,
EngineExecutionListener engineExecutionListener, ExecutionRequest request) {
boolean parallelExecutionEnabled = getParallelExecutionEnabled(request);

if (parallelExecutionEnabled) {
if (executeInParallel(engineDescriptor, engineExecutionListener, request)) {
Thread.currentThread().interrupt();
}
}
else {
executeSequentially(engineDescriptor, engineExecutionListener);
}
}

private boolean executeInParallel(VintageEngineDescriptor engineDescriptor,
EngineExecutionListener engineExecutionListener, ExecutionRequest request) {
ExecutorService executorService = Executors.newFixedThreadPool(getThreadPoolSize(request));
RunnerExecutor runnerExecutor = new RunnerExecutor(engineExecutionListener);

List<CompletableFuture<Void>> futures = new ArrayList<>();
for (Iterator<TestDescriptor> iterator = engineDescriptor.getModifiableChildren().iterator(); iterator.hasNext();) {
TestDescriptor descriptor = iterator.next();
CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
runnerExecutor.execute((RunnerTestDescriptor) descriptor);
}, executorService);

futures.add(future);
iterator.remove();
}

CompletableFuture<Void> allOf = CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]));
boolean wasInterrupted = false;
try {
allOf.get();
}
catch (InterruptedException e) {
logger.warn(e, () -> "Interruption while waiting for parallel test execution to finish");
wasInterrupted = true;
}
catch (ExecutionException e) {
throw ExceptionUtils.throwAsUncheckedException(e.getCause());
}
finally {
shutdownExecutorService(executorService);
}
return wasInterrupted;
}

private void shutdownExecutorService(ExecutorService executorService) {
try {
executorService.shutdown();
marcphilipp marked this conversation as resolved.
Show resolved Hide resolved
if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
logger.warn(() -> "Executor service did not terminate within the specified timeout");
executorService.shutdownNow();
}
}
catch (InterruptedException e) {
logger.warn(e, () -> "Interruption while waiting for executor service to shut down");
marcphilipp marked this conversation as resolved.
Show resolved Hide resolved
Thread.currentThread().interrupt();
}
}

private void executeSequentially(VintageEngineDescriptor engineDescriptor,
EngineExecutionListener engineExecutionListener) {
RunnerExecutor runnerExecutor = new RunnerExecutor(engineExecutionListener);
for (Iterator<TestDescriptor> iterator = engineDescriptor.getModifiableChildren().iterator(); iterator.hasNext();) {
Expand All @@ -82,4 +161,21 @@ private void executeAllChildren(VintageEngineDescriptor engineDescriptor,
}
}

private boolean getParallelExecutionEnabled(ExecutionRequest request) {
return request.getConfigurationParameters().getBoolean(PARALLEL_EXECUTION_ENABLED).orElse(false);
}

private int getThreadPoolSize(ExecutionRequest request) {
Optional<String> poolSize = request.getConfigurationParameters().get(PARALLEL_POOL_SIZE);
if (poolSize.isPresent()) {
try {
return Integer.parseInt(poolSize.get());
}
catch (NumberFormatException e) {
logger.warn(() -> "Invalid value for parallel pool size: " + poolSize.get());
}
}
return DEFAULT_THREAD_POOL_SIZE;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* All rights reserved. This program and the accompanying materials are
* made available under the terms of the Eclipse Public License v2.0 which
* accompanies this distribution and is available at
*
* https://www.eclipse.org/legal/epl-v20.html
*/

package org.junit.vintage.engine.execution;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.platform.testkit.engine.EventConditions.event;
import static org.junit.platform.testkit.engine.EventConditions.finishedSuccessfully;
import static org.junit.platform.testkit.engine.EventConditions.finishedWithFailure;
import static org.junit.platform.testkit.engine.EventConditions.started;
import static org.junit.platform.testkit.engine.EventConditions.test;
import static org.junit.vintage.engine.samples.junit4.JUnit4ParallelTestCase.AtomicOperationParallelTestCase;
import static org.junit.vintage.engine.samples.junit4.JUnit4ParallelTestCase.ConcurrentFailureTestCase;
import static org.junit.vintage.engine.samples.junit4.JUnit4ParallelTestCase.ConcurrentIncrementTestCase;
import static org.junit.vintage.engine.samples.junit4.JUnit4ParallelTestCase.FailingParallelTestCase;
import static org.junit.vintage.engine.samples.junit4.JUnit4ParallelTestCase.ParallelFailingTestCase;
import static org.junit.vintage.engine.samples.junit4.JUnit4ParallelTestCase.SuccessfulParallelTestCase;

import java.time.Instant;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.assertj.core.api.Condition;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestReporter;
import org.junit.platform.engine.discovery.ClassSelector;
import org.junit.platform.engine.discovery.DiscoverySelectors;
import org.junit.platform.launcher.LauncherDiscoveryRequest;
import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder;
import org.junit.platform.testkit.engine.EngineExecutionResults;
import org.junit.platform.testkit.engine.EngineTestKit;
import org.junit.platform.testkit.engine.Event;
import org.junit.platform.testkit.engine.Events;
import org.junit.vintage.engine.VintageTestEngine;

class ParallelExecutionIntegrationTests {

private static final String PARALLEL_EXECUTION_ENABLED = "junit.vintage.execution.parallel.enabled";
private static final String PARALLEL_POOL_SIZE = "junit.vintage.execution.parallel.pool-size";

@Test
void successfulParallelTest(TestReporter reporter) {
var events = executeInParallelSuccessfully(3, SuccessfulParallelTestCase.class,
ConcurrentIncrementTestCase.class, AtomicOperationParallelTestCase.class).list();

var startedTimestamps = getTimestampsFor(events, event(test(), started()));
var finishedTimestamps = getTimestampsFor(events, event(test(), finishedSuccessfully()));
var threadNames = new HashSet<>(Set.of(SuccessfulParallelTestCase.threadNames,
ConcurrentIncrementTestCase.threadNames, AtomicOperationParallelTestCase.threadNames));

reporter.publishEntry("startedTimestamps", startedTimestamps.toString());
reporter.publishEntry("finishedTimestamps", finishedTimestamps.toString());

assertThat(startedTimestamps).hasSize(9);
assertThat(finishedTimestamps).hasSize(9);
assertThat(threadNames).hasSize(3);
}

@Test
void failingParallelTest(TestReporter reporter) {
var events = executeInParallel(3, FailingParallelTestCase.class, ConcurrentFailureTestCase.class,
ParallelFailingTestCase.class).list();

var startedTimestamps = getTimestampsFor(events, event(test(), started()));
var finishedTimestamps = getTimestampsFor(events, event(test(), finishedWithFailure()));
var threadNames = new HashSet<>(Set.of(FailingParallelTestCase.threadNames,
ConcurrentFailureTestCase.threadNames, ParallelFailingTestCase.threadNames));

reporter.publishEntry("startedTimestamps", startedTimestamps.toString());
reporter.publishEntry("finishedTimestamps", finishedTimestamps.toString());

assertThat(startedTimestamps).hasSize(9);
assertThat(finishedTimestamps).hasSize(9);
assertThat(threadNames).hasSize(3);
}

private List<Instant> getTimestampsFor(List<Event> events, Condition<Event> condition) {
// @formatter:off
return events.stream()
.filter(condition::matches)
.map(Event::getTimestamp)
.toList();
// @formatter:on
}

private Events executeInParallelSuccessfully(int poolSize, Class<?>... testClasses) {
var events = execute(poolSize, testClasses).allEvents();
try {
return events.assertStatistics(it -> it.failed(0));
}
catch (AssertionError error) {
events.debug();
throw error;
}
}

private Events executeInParallel(int poolSize, Class<?>... testClasses) {
return execute(poolSize, testClasses).allEvents();
}

private static EngineExecutionResults execute(int poolSize, Class<?>... testClass) {
return EngineTestKit.execute(new VintageTestEngine(), request(poolSize, testClass));
}

private static LauncherDiscoveryRequest request(int poolSize, Class<?>... testClasses) {
var classSelectors = Arrays.stream(testClasses) //
.map(DiscoverySelectors::selectClass) //
.toArray(ClassSelector[]::new);

return LauncherDiscoveryRequestBuilder.request().selectors(classSelectors).configurationParameter(
PARALLEL_EXECUTION_ENABLED, String.valueOf(true)).configurationParameter(PARALLEL_POOL_SIZE,
String.valueOf(poolSize)).build();
}

}
Loading
Loading