Skip to content

Commit

Permalink
Use Worker API to download files in parallel
Browse files Browse the repository at this point in the history
Refers to #138
  • Loading branch information
michel-kraemer committed Jan 16, 2022
1 parent add4f2a commit 0f07235
Show file tree
Hide file tree
Showing 16 changed files with 317 additions and 52 deletions.
12 changes: 6 additions & 6 deletions src/main/java/de/undercouch/gradle/tasks/download/Download.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ public Download() {
*/
@TaskAction
public void download() throws IOException {
action.execute();

// handle 'upToDate'
if (action.isUpToDate()) {
getState().setDidWork(false);
}
action.execute().thenRun(() -> {
// handle 'upToDate'
if (action.isUpToDate()) {
getState().setDidWork(false);
}
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import de.undercouch.gradle.tasks.download.internal.CachingHttpClientFactory;
import de.undercouch.gradle.tasks.download.internal.HttpClientFactory;
import de.undercouch.gradle.tasks.download.internal.ProgressLoggerWrapper;
import de.undercouch.gradle.tasks.download.internal.WorkerExecutorHelper;
import groovy.json.JsonOutput;
import groovy.json.JsonSlurper;
import groovy.lang.Closure;
Expand Down Expand Up @@ -44,10 +45,12 @@
import org.gradle.api.JavaVersion;
import org.gradle.api.Project;
import org.gradle.api.Task;
import org.gradle.api.UncheckedIOException;
import org.gradle.api.file.Directory;
import org.gradle.api.file.ProjectLayout;
import org.gradle.api.file.RegularFile;
import org.gradle.api.logging.Logger;
import org.gradle.api.model.ObjectFactory;
import org.gradle.api.provider.Provider;
import org.gradle.util.GradleVersion;

Expand All @@ -71,6 +74,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -92,6 +96,7 @@ public class DownloadAction implements DownloadSpec {
private final ProjectLayout projectLayout;
private final Logger logger;
private final Object servicesOwner;
private final ObjectFactory objectFactory;
private final boolean isOffline;
private final List<Object> sourceObjects = new ArrayList<>(1);
private List<URL> cachedSources;
Expand Down Expand Up @@ -139,15 +144,18 @@ public DownloadAction(Project project, @Nullable Task task) {
} else {
this.servicesOwner = project;
}
this.objectFactory = project.getObjects();
this.isOffline = project.getGradle().getStartParameter().isOffline();
this.downloadTaskDir = new File(project.getBuildDir(), "download-task");
}

/**
* Starts downloading
* @return a {@link CompletableFuture} that completes once the download
* has finished
* @throws IOException if the file could not downloaded
*/
public void execute() throws IOException {
public CompletableFuture<Void> execute() throws IOException {
if (GradleVersion.current().compareTo(HARD_MIN_GRADLE_VERSION) < 0 && !quiet) {
throw new IllegalStateException("gradle-download-task requires " +
"Gradle 5.x or higher");
Expand Down Expand Up @@ -181,7 +189,7 @@ public void execute() throws IOException {
//make sure build dir exists
dest.mkdirs();
}

if (sources.size() > 1 && !dest.isDirectory()) {
if (!dest.exists()) {
// create directory automatically
Expand All @@ -191,18 +199,58 @@ public void execute() throws IOException {
+ "the destination has to be a directory.");
}
}


WorkerExecutorHelper workerExecutor = WorkerExecutorHelper.newInstance(objectFactory);

CachingHttpClientFactory clientFactory = new CachingHttpClientFactory();
try {
for (URL src : sources) {
execute(src, clientFactory);
CompletableFuture<?>[] futures = new CompletableFuture[sources.size()];
for (int i = 0; i < sources.size(); i++) {
URL src = sources.get(i);

// create progress logger
ProgressLoggerWrapper progressLogger = new ProgressLoggerWrapper(logger);
if (!quiet) {
try {
progressLogger.init(servicesOwner, src.toString());
} catch (Exception e) {
// unable to get progress logger
logger.error("Unable to get progress logger. Download "
+ "progress will not be displayed.");
}
}
} finally {
clientFactory.close();

// submit download job for asynchronous execution
CompletableFuture<Void> f = new CompletableFuture<>();
futures[i] = f;
workerExecutor.submit(() -> {
try {
execute(src, clientFactory, progressLogger);
f.complete(null);
} catch (Throwable t) {
f.completeExceptionally(t);
throw t;
}
});
}

// wait for all downloads to finish (necessary if we're on an old
// Gradle version (< 5.6) without Worker API)
if (workerExecutor.needsAwait()) {
workerExecutor.await();
}

return CompletableFuture.allOf(futures).whenComplete((v, t) -> {
// always close HTTP client factory
try {
clientFactory.close();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});
}

private void execute(URL src, HttpClientFactory clientFactory) throws IOException {
private void execute(URL src, HttpClientFactory clientFactory,
ProgressLoggerWrapper progressLogger) throws IOException {
final File destFile = makeDestFile(src);
if (!overwrite && destFile.exists()) {
if (!quiet) {
Expand All @@ -229,18 +277,6 @@ private void execute(URL src, HttpClientFactory clientFactory) throws IOExceptio

final long timestamp = onlyIfModified && destFile.exists() ? destFile.lastModified() : 0;

// create progress logger
ProgressLoggerWrapper progressLogger = new ProgressLoggerWrapper(logger);
if (!quiet) {
try {
progressLogger.init(servicesOwner, src.toString());
} catch (Exception e) {
// unable to get progress logger
logger.error("Unable to get progress logger. Download "
+ "progress will not be displayed.");
}
}

if ("file".equals(src.getProtocol())) {
executeFileProtocol(src, timestamp, destFile, progressLogger);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public DownloadExtension(Project project) {
public DownloadExtension configure(Closure cl) {
DownloadAction da = ConfigureUtil.configure(cl, new DownloadAction(project));
try {
da.execute();
} catch (IOException e) {
da.execute().get();
} catch (Exception e) {
String message = e.getMessage();
if (message == null) {
message = "Could not download file";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package de.undercouch.gradle.tasks.download.internal;

import org.gradle.api.UncheckedIOException;
import org.gradle.api.provider.Property;
import org.gradle.workers.WorkAction;
import org.gradle.workers.WorkParameters;
import org.gradle.workers.WorkQueue;
import org.gradle.workers.WorkerExecutor;

import javax.inject.Inject;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
* Default implementation of {@link WorkerExecutorHelper} that executes
* {@link Job}s asynchronously using the Gradle Worker API
* @author Michel Kraemer
*/
@SuppressWarnings("UnstableApiUsage")
public class DefaultWorkerExecutorHelper extends WorkerExecutorHelper {
/**
* A unique ID for jobs. Used to access jobs in {@link #jobs}
*/
private static final AtomicInteger UNIQUE_ID = new AtomicInteger();

/**
* A maps of jobs submitted to this executor
*/
private static final Map<Integer, Job> jobs = new ConcurrentHashMap<>();

private final WorkerExecutor workerExecutor;
private final WorkQueue workQueue;

/**
* Constructs a new executor
* @param workerExecutor the Gradle Worker API executor
*/
@Inject
public DefaultWorkerExecutorHelper(WorkerExecutor workerExecutor) {
this.workerExecutor = workerExecutor;
this.workQueue = workerExecutor.noIsolation();
}

@Override
public void submit(Job job) {
int id = UNIQUE_ID.getAndIncrement();
jobs.put(id, job);
workQueue.submit(DefaultWorkAction.class, parameters ->
parameters.getID().set(id));
}

@Override
public void await() {
workerExecutor.await();
}

@Override
public boolean needsAwait() {
return false;
}

public interface DefaultWorkParameters extends WorkParameters {
Property<Integer> getID();
}

public static abstract class DefaultWorkAction implements WorkAction<DefaultWorkParameters> {
@Override
public void execute() {
Job job = jobs.remove(getParameters().getID().get());
try {
job.run();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package de.undercouch.gradle.tasks.download.internal;

import java.io.IOException;

/**
* An asynchronous job executed by a {@link WorkerExecutorHelper}
* @author Michel Kraemer
*/
public interface Job {
/**
* Execute the job
* @throws IOException if the job failed
*/
void run() throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package de.undercouch.gradle.tasks.download.internal;

import java.io.IOException;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
* Executes jobs asynchronously with an {@link ExecutorService} on Gradle
* versions where the Worker API is not available
* @author Michel Kraemer
*/
public class LegacyWorkerExecutorHelper extends WorkerExecutorHelper {
private final ExecutorService executorService = Executors.newWorkStealingPool();
private final Queue<Future<Void>> futures = new ConcurrentLinkedQueue<>();

@Override
public void submit(Job job) {
CompletableFuture<Void> f = new CompletableFuture<>();
futures.add(f);
executorService.submit(() -> {
try {
job.run();
f.complete(null);
futures.remove(f);
} catch (IOException e) {
f.completeExceptionally(e);
}
});
}

@Override
public void await() {
Future<Void> f;
while ((f = futures.poll()) != null) {
try {
f.get();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

@Override
public boolean needsAwait() {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package de.undercouch.gradle.tasks.download.internal;

import org.gradle.api.model.ObjectFactory;
import org.gradle.util.GradleVersion;

/**
* Executes jobs asynchronously. Either uses the Gradle Worker API (if
* available) or falls back to a legacy implementation using an
* {@link java.util.concurrent.ExecutorService}.
* @author Michel Kraemer
*/
public abstract class WorkerExecutorHelper {
/**
* Creates a new instance of the {@link WorkerExecutorHelper} depending
* on the Gradle version
* @param objectFactory creates Gradle model objects
* @return the helper
*/
public static WorkerExecutorHelper newInstance(ObjectFactory objectFactory) {
if (GradleVersion.current().getBaseVersion().compareTo(GradleVersion.version("5.6")) >= 0) {
return objectFactory.newInstance(DefaultWorkerExecutorHelper.class);
}
return new LegacyWorkerExecutorHelper();
}

/**
* Execute a job asynchronously
* @param job the job to execute
*/
public abstract void submit(Job job);

/**
* Wait for all jobs of the current build operation to complete
*/
public abstract void await();

/**
* Returns {@code true} if {@link #await()} MUST be called at the end of
* the task. This mostly applies to Gradle versions that don't have a
* Worker API and therefore cannot let the task continue to run in parallel
* to others.
* @return {@code true} if {@link #await()} must be called
*/
public abstract boolean needsAwait();
}
Loading

0 comments on commit 0f07235

Please sign in to comment.