feat(runner): handle output dir + labels + better use of core script runners' steps (#409)
brian-mulier-p authored Apr 5, 2024
1 parent 9e9c3ea commit a67c9e2
Showing 2 changed files with 99 additions and 155 deletions.
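In short, the runner stops managing its own s3Bucket/workingDir bookkeeping and instead derives a per-run working directory, a nested output directory, and the matching s3:// prefix from the core script-runner contract (ScriptService), registers Kestra labels as AWS Batch job tags, and cleans up S3 by listing the working-directory prefix. Below is a minimal, self-contained sketch of that directory/variable layout; the values and the string keys are illustrative assumptions, not taken from the commit (the real keys are the ScriptService.VAR_WORKING_DIR / VAR_OUTPUT_DIR / VAR_BUCKET_PATH constants used in the diff).

import java.nio.file.Path;
import java.util.Map;

class WorkingDirLayoutSketch {
    public static void main(String[] args) {
        // A random per-run id isolates concurrent jobs; stands in for IdUtils.create().
        String runId = "20240405example";
        Path workingDir = Path.of("/" + runId);            // mounted into every container of the job
        Path outputDir = workingDir.resolve("out");        // created by the input sidecar, synced back by the output sidecar
        String bucketPath = "s3://my-bucket" + workingDir; // prefix the "aws s3 cp" sidecars copy to/from
        Map<String, Object> vars = Map.of(
            "workingDir", workingDir,
            "outputDir", outputDir,
            "bucketPath", bucketPath
        );
        vars.forEach((k, v) -> System.out.println(k + " -> " + v));
    }
}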
176 changes: 91 additions & 85 deletions src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
@@ -31,15 +31,15 @@
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
-import software.amazon.awssdk.transfer.s3.model.DownloadFileRequest;
-import software.amazon.awssdk.transfer.s3.model.FileDownload;
-import software.amazon.awssdk.transfer.s3.model.FileUpload;
-import software.amazon.awssdk.transfer.s3.model.UploadFileRequest;
+import software.amazon.awssdk.transfer.s3.model.*;

import java.net.URI;
import java.nio.file.Path;
import java.time.Duration;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
@@ -71,7 +71,7 @@ Upon worker restart, this job will be requeued and executed again. Moreover, the
- SUBMITTED: 6
- OTHER: -1""")
@Plugin(examples = {}, beta = true)
-public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, AbstractConnectionInterface {
+public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, AbstractConnectionInterface, RemoteRunnerInterface {
private static final Map<JobStatus, Integer> exitCodeByStatus = Map.of(
JobStatus.FAILED, 1,
JobStatus.RUNNING, 2,
@@ -81,8 +81,6 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
JobStatus.SUBMITTED, 6,
JobStatus.UNKNOWN_TO_SDK_VERSION, -1
);
-public static final String S3_WORKING_DIR_KEY = "s3WorkingDir";
-public static final String WORKING_DIR_KEY = "workingDir";

@NotNull
private String region;
@@ -119,7 +117,7 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
description = "It's mandatory to provide a bucket if you want to use such properties."
)
@PluginProperty(dynamic = true)
-private String s3Bucket;
+private String bucket;

@Schema(
title = "Execution role to use to run the job.",
@@ -154,28 +152,13 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
title = "The maximum duration to wait for the job completion. AWS Batch will automatically timeout the job upon reaching such duration and the task will be failed."
)
@Builder.Default
-private final Duration waitUntilCompletion = Duration.ofHours(1);
+private Duration waitUntilCompletion = Duration.ofHours(1);

@Override
-public RunnerResult run(RunContext runContext, ScriptCommands commands, List<String> filesToUploadWithoutInternalStorage, List<String> filesToDownload) throws Exception {
-boolean hasS3Bucket = this.s3Bucket != null;
-
-
-String renderedBucket = runContext.render(s3Bucket);
-String workingDirName = IdUtils.create();
-Map<String, Object> additionalVars = commands.getAdditionalVars();
-Optional.ofNullable(renderedBucket).ifPresent(bucket -> additionalVars.putAll(Map.<String, Object>of(
-S3_WORKING_DIR_KEY, "s3://" + bucket + "/" + workingDirName,
-WORKING_DIR_KEY, "/" + workingDirName,
-"outputDir", "/" + workingDirName
-)));
-
-List<String> filesToUpload = new ArrayList<>(ListUtils.emptyOnNull(filesToUploadWithoutInternalStorage));
-List<String> command = ScriptService.uploadInputFiles(
-runContext,
-runContext.render(commands.getCommands(), additionalVars),
-(ignored, localFilePath) -> filesToUpload.add(localFilePath)
-);
+public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
+boolean hasS3Bucket = this.bucket != null;
+
+String renderedBucket = runContext.render(bucket);

boolean hasFilesToUpload = !ListUtils.isEmpty(filesToUpload);
if (hasFilesToUpload && !hasS3Bucket) {
@@ -187,7 +170,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<Str
}

Logger logger = runContext.logger();
-AbstractLogConsumer logConsumer = commands.getLogConsumer();
+AbstractLogConsumer logConsumer = scriptCommands.getLogConsumer();

String renderedRegion = runContext.render(this.region);
Region regionObject = Region.of(renderedRegion);
@@ -204,7 +187,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<Str
BatchClient client = batchClientBuilder
.build();

-String jobName = IdUtils.create();
+String jobName = ScriptService.jobName(runContext);

ComputeEnvironmentDetail computeEnvironmentDetail = client.describeComputeEnvironments(
DescribeComputeEnvironmentsRequest.builder()
@@ -234,7 +217,9 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<Str
RegisterJobDefinitionRequest.Builder jobDefBuilder = RegisterJobDefinitionRequest.builder()
.jobDefinitionName(jobDefinitionName)
.type(JobDefinitionType.CONTAINER)
+.tags(ScriptService.labels(runContext, "kestra-", true, true))
.platformCapabilities(platformCapability);
+Path batchWorkingDirectory = (Path) this.additionalVars(runContext, scriptCommands).get(ScriptService.VAR_WORKING_DIR);

if (hasFilesToUpload) {
try (S3TransferManager transferManager = transferManager(runContext)) {
@@ -245,7 +230,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<Str
.builder()
.bucket(renderedBucket)
// Use path to eventually deduplicate leading '/'
-.key(workingDirName + Path.of("/" + relativePath))
+.key((batchWorkingDirectory + Path.of("/" + relativePath).toString()).substring(1))
.build()
)
.source(runContext.resolve(Path.of(relativePath)))
@@ -287,42 +272,43 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<Str

int baseSideContainerMemory = 128;
float baseSideContainerCpu = 0.1f;
-Object s3WorkingDir = additionalVars.get(S3_WORKING_DIR_KEY);
-if (hasFilesToUpload) {
+Map<String, Object> additionalVars = this.additionalVars(runContext, scriptCommands);
+Object s3WorkingDir = additionalVars.get(ScriptService.VAR_BUCKET_PATH);
+Path batchOutputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR);
+MountPoint volumeMount = MountPoint.builder()
+.containerPath(batchWorkingDirectory.toString())
+.sourceVolume(kestraVolume)
+.build();
+
+if (hasS3Bucket) {
containers.add(
withResources(
TaskContainerProperties.builder()
.image("ghcr.io/kestra-io/awsbatch:latest")
-.mountPoints(
-MountPoint.builder()
-.containerPath("/" + workingDirName)
-.sourceVolume(kestraVolume)
-.build()
-)
+.mountPoints(volumeMount)
.essential(false)
.command(ScriptService.scriptCommands(
List.of("/bin/sh", "-c"),
null,
-filesToUpload.stream()
-.map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " /" + workingDirName + Path.of("/" + relativePath))
-.toList()
+Stream.concat(
+ListUtils.emptyOnNull(filesToUpload).stream()
+.map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath)),
+Stream.of("mkdir " + batchOutputDirectory)
+).toList()
))
.name(inputFilesContainerName),
baseSideContainerMemory,
baseSideContainerCpu).build()
);
}

-int sideContainersMemoryAllocations = (hasFilesToUpload ? baseSideContainerMemory : 0) + (hasFilesToDownload ? baseSideContainerMemory : 0);
-float sideContainersCpuAllocations = (hasFilesToUpload ? baseSideContainerCpu : 0) + (hasFilesToDownload ? baseSideContainerCpu : 0);
-
-Map<String, String> environment = Optional.ofNullable(commands.getEnv()).orElse(new HashMap<>());
-environment.put("WORKING_DIR", "/" + workingDirName);
+int sideContainersMemoryAllocations = hasS3Bucket ? baseSideContainerMemory * 2 : 0;
+float sideContainersCpuAllocations = hasS3Bucket ? baseSideContainerCpu * 2 : 0;

TaskContainerProperties.Builder mainContainerBuilder = withResources(
TaskContainerProperties.builder()
-.image(commands.getContainerImage())
-.command(command)
+.image(scriptCommands.getContainerImage())
+.command(scriptCommands.getCommands())
.name(mainContainerName)
.logConfiguration(
LogConfiguration.builder()
@@ -331,47 +317,36 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<Str
.build()
)
.environment(
-environment.entrySet().stream()
+this.env(runContext, scriptCommands).entrySet().stream()
.map(e -> KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build())
.toArray(KeyValuePair[]::new)
)
-.essential(!hasFilesToDownload),
+.essential(!hasS3Bucket),
Integer.parseInt(resources.getRequest().getMemory()) - sideContainersMemoryAllocations,
Float.parseFloat(resources.getRequest().getCpu()) - sideContainersCpuAllocations
);

-if (hasFilesToUpload) {
+if (hasS3Bucket) {
mainContainerBuilder.dependsOn(TaskContainerDependency.builder().containerName(inputFilesContainerName).condition("SUCCESS").build());
-}
-
-if (hasFilesToUpload || hasFilesToDownload) {
-mainContainerBuilder.mountPoints(
-MountPoint.builder()
-.containerPath("/" + workingDirName)
-.sourceVolume(kestraVolume)
-.build()
-);
+mainContainerBuilder.mountPoints(volumeMount);
}

containers.add(mainContainerBuilder.build());

-if (hasFilesToDownload) {
+if (hasS3Bucket) {
containers.add(
withResources(
TaskContainerProperties.builder()
.image("ghcr.io/kestra-io/awsbatch:latest")
-.mountPoints(
-MountPoint.builder()
-.containerPath("/" + workingDirName)
-.sourceVolume(kestraVolume)
-.build()
-)
+.mountPoints(volumeMount)
.command(ScriptService.scriptCommands(
List.of("/bin/sh", "-c"),
null,
-filesToDownload.stream()
-.map(relativePath -> "aws s3 cp /" + workingDirName + "/" + relativePath + " " + s3WorkingDir + "/" + relativePath)
-.toList()
+Stream.concat(
+filesToDownload.stream()
+.map(relativePath -> "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)),
+Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive")
+).toList()
))
.dependsOn(TaskContainerDependency.builder().containerName(mainContainerName).condition("SUCCESS").build())
.name(outputFilesContainerName),
@@ -480,32 +455,57 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List<Str
throw new ScriptException(exitCode, logConsumer.getStdOutCount(), logConsumer.getStdErrCount());
}

-if (hasFilesToDownload) {
+if (hasS3Bucket) {
try (S3TransferManager transferManager = transferManager(runContext)) {
filesToDownload.stream().map(relativePath -> transferManager.downloadFile(
DownloadFileRequest.builder()
.getObjectRequest(GetObjectRequest.builder()
.bucket(renderedBucket)
.key(workingDirName + "/" + relativePath)
.key((batchWorkingDirectory + "/" + relativePath).substring(1))
.build())
-.destination(runContext.resolve(Path.of(relativePath)))
+.destination(scriptCommands.getWorkingDirectory().resolve(Path.of(relativePath.startsWith("/") ? relativePath.substring(1) : relativePath)))
.build()
)).map(FileDownload::completionFuture)
.forEach(throwConsumer(CompletableFuture::get));

+transferManager.downloadDirectory(DownloadDirectoryRequest.builder()
+.bucket(renderedBucket)
+.destination(scriptCommands.getOutputDirectory())
+.listObjectsV2RequestTransformer(builder -> builder
+.prefix(batchOutputDirectory.toString().substring(1))
+)
+.build())
+.completionFuture()
+.get();
}
}
} finally {
cleanupBatchResources(client, jobQueue, jobDefArn);
// Manual close after cleanup to make sure we get all remaining logs
cloudWatchLogsAsyncClient.close();
-if (hasFilesToUpload || hasFilesToDownload) {
-cleanupS3Resources(runContext, filesToUpload, filesToDownload, workingDirName);
+if (hasS3Bucket) {
+cleanupS3Resources(runContext, batchWorkingDirectory);
}
}

return new RunnerResult(0, logConsumer);
}

+@Override
+public Map<String, Object> runnerAdditionalVars(RunContext runContext, ScriptCommands scriptCommands) throws IllegalVariableEvaluationException {
+Map<String, Object> additionalVars = new HashMap<>();
+Path batchWorkingDirectory = Path.of("/" + IdUtils.create());
+additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory);
+
+if (this.bucket != null) {
+Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create());
+additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory);
+additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory);
+}
+
+return additionalVars;
+}

@Nullable
private static String getLogGroupArn(CloudWatchLogsAsyncClient cloudWatchLogsAsyncClient, String logGroupName, String renderedRegion) throws InterruptedException, ExecutionException {
return cloudWatchLogsAsyncClient.describeLogGroups(
@@ -534,15 +534,21 @@ private TaskContainerProperties.Builder withResources(TaskContainerProperties.Bu
);
}

-private void cleanupS3Resources(RunContext runContext, List<String> filesToUpload, List<String> filesToDownload, String workingDirName) throws IllegalVariableEvaluationException {
-String renderedBucket = runContext.render(s3Bucket);
+private void cleanupS3Resources(RunContext runContext, Path batchWorkingDirectory) throws IllegalVariableEvaluationException {
+String renderedBucket = runContext.render(bucket);
try (S3AsyncClient s3AsyncClient = asyncClient(runContext)) {
+ListObjectsV2Request listRequest = ListObjectsV2Request.builder()
+.bucket(renderedBucket)
+.prefix(batchWorkingDirectory.toString())
+.build();
+
+ListObjectsV2Response listResponse = s3AsyncClient.listObjectsV2(listRequest).get();
+List<ObjectIdentifier> objectsIdentifiers = listResponse.contents().stream()
+.map(S3Object::key)
+.map(key -> ObjectIdentifier.builder().key(key).build())
+.toList();
StreamSupport.stream(Iterables.partition(
-Stream.concat(
-Optional.ofNullable(filesToUpload).stream().flatMap(Collection::stream),
-Optional.ofNullable(filesToDownload).stream().flatMap(Collection::stream)
-).map(file -> ObjectIdentifier.builder().key("/" + workingDirName + "/" + file).build())
-.toList(),
+objectsIdentifiers,
1000
).spliterator(), false)
.map(objects -> s3AsyncClient.deleteObjects(
@@ -554,7 +560,7 @@ private void cleanupS3Resources(RunContext runContext, List<String> filesToUploa
.build()
)).forEach(throwConsumer(CompletableFuture::get));
} catch (Exception e) {
-runContext.logger().warn("Error while cleaning up S3: {}", e.getMessage());
+runContext.logger().warn("Error while cleaning up S3: {}", e.toString());
}
}

