From a67c9e291d72659a005e7bc8dd9094ba0a5ae432 Mon Sep 17 00:00:00 2001
From: brian-mulier-p
Date: Fri, 5 Apr 2024 15:04:37 +0200
Subject: [PATCH] feat(runner): handle output dir + labels + better use of core
 script runners' steps (#409)

---
 .../aws/runner/AwsBatchScriptRunner.java      | 176 +++++++++---------
 .../aws/runner/AwsBatchScriptRunnerTest.java  |  78 +-------
 2 files changed, 99 insertions(+), 155 deletions(-)

diff --git a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
index 9532186..12ccafc 100644
--- a/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
+++ b/src/main/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunner.java
@@ -31,15 +31,15 @@
 import software.amazon.awssdk.services.s3.S3AsyncClient;
 import software.amazon.awssdk.services.s3.model.*;
 import software.amazon.awssdk.transfer.s3.S3TransferManager;
-import software.amazon.awssdk.transfer.s3.model.DownloadFileRequest;
-import software.amazon.awssdk.transfer.s3.model.FileDownload;
-import software.amazon.awssdk.transfer.s3.model.FileUpload;
-import software.amazon.awssdk.transfer.s3.model.UploadFileRequest;
+import software.amazon.awssdk.transfer.s3.model.*;
 
 import java.net.URI;
 import java.nio.file.Path;
 import java.time.Duration;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeoutException;
@@ -71,7 +71,7 @@ Upon worker restart, this job will be requeued and executed again. Moreover, the
     - SUBMITTED: 6
     - OTHER: -1""")
 @Plugin(examples = {}, beta = true)
-public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, AbstractConnectionInterface {
+public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, AbstractConnectionInterface, RemoteRunnerInterface {
     private static final Map<JobStatus, Integer> exitCodeByStatus = Map.of(
         JobStatus.FAILED, 1,
         JobStatus.RUNNING, 2,
@@ -81,8 +81,6 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
         JobStatus.SUBMITTED, 6,
         JobStatus.UNKNOWN_TO_SDK_VERSION, -1
     );
-    public static final String S3_WORKING_DIR_KEY = "s3WorkingDir";
-    public static final String WORKING_DIR_KEY = "workingDir";
 
     @NotNull
     private String region;
@@ -119,7 +117,7 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
         description = "It's mandatory to provide a bucket if you want to use such properties."
     )
     @PluginProperty(dynamic = true)
-    private String s3Bucket;
+    private String bucket;
 
     @Schema(
         title = "Execution role to use to run the job.",
@@ -154,28 +152,13 @@ public class AwsBatchScriptRunner extends ScriptRunner implements AbstractS3, Ab
         title = "The maximum duration to wait for the job completion. AWS Batch will automatically timeout the job upon reaching such duration and the task will be failed."
     )
     @Builder.Default
-    private final Duration waitUntilCompletion = Duration.ofHours(1);
+    private Duration waitUntilCompletion = Duration.ofHours(1);
 
     @Override
-    public RunnerResult run(RunContext runContext, ScriptCommands commands, List<String> filesToUploadWithoutInternalStorage, List<String> filesToDownload) throws Exception {
-        boolean hasS3Bucket = this.s3Bucket != null;
-
-        String renderedBucket = runContext.render(s3Bucket);
-        String workingDirName = IdUtils.create();
-        Map<String, Object> additionalVars = commands.getAdditionalVars();
-        Optional.ofNullable(renderedBucket).ifPresent(bucket -> additionalVars.putAll(Map.of(
-            S3_WORKING_DIR_KEY, "s3://" + bucket + "/" + workingDirName,
-            WORKING_DIR_KEY, "/" + workingDirName,
-            "outputDir", "/" + workingDirName
-        )));
-
-        List<String> filesToUpload = new ArrayList<>(ListUtils.emptyOnNull(filesToUploadWithoutInternalStorage));
-        List<String> command = ScriptService.uploadInputFiles(
-            runContext,
-            runContext.render(commands.getCommands(), additionalVars),
-            (ignored, localFilePath) -> filesToUpload.add(localFilePath)
-        );
+    public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
+        boolean hasS3Bucket = this.bucket != null;
+
+        String renderedBucket = runContext.render(bucket);
 
         boolean hasFilesToUpload = !ListUtils.isEmpty(filesToUpload);
         if (hasFilesToUpload && !hasS3Bucket) {
@@ -187,7 +170,7 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List
+        Map<String, Object> additionalVars = this.additionalVars(runContext, scriptCommands);
+        Path batchWorkingDirectory = (Path) additionalVars.get(ScriptService.VAR_WORKING_DIR);
+        Object s3WorkingDir = additionalVars.get(ScriptService.VAR_BUCKET_PATH);
+        Path batchOutputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR);
+        MountPoint volumeMount = MountPoint.builder()
+            .containerPath(batchWorkingDirectory.toString())
+            .sourceVolume(kestraVolume)
+            .build();
+
+        if (hasS3Bucket) {
             containers.add(
                 withResources(
                     TaskContainerProperties.builder()
                         .image("ghcr.io/kestra-io/awsbatch:latest")
-                        .mountPoints(
-                            MountPoint.builder()
-                                .containerPath("/" + workingDirName)
-                                .sourceVolume(kestraVolume)
-                                .build()
-                        )
+                        .mountPoints(volumeMount)
                         .essential(false)
                         .command(ScriptService.scriptCommands(
                             List.of("/bin/sh", "-c"),
                             null,
-                            filesToUpload.stream()
-                                .map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " /" + workingDirName + Path.of("/" + relativePath))
-                                .toList()
+                            Stream.concat(
+                                ListUtils.emptyOnNull(filesToUpload).stream()
+                                    .map(relativePath -> "aws s3 cp " + s3WorkingDir + Path.of("/" + relativePath) + " " + batchWorkingDirectory + Path.of("/" + relativePath)),
+                                Stream.of("mkdir " + batchOutputDirectory)
+                            ).toList()
                         ))
                         .name(inputFilesContainerName),
                     baseSideContainerMemory,
@@ -313,16 +302,13 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List
-        Map<String, String> environment = Optional.ofNullable(commands.getEnv()).orElse(new HashMap<>());
-        environment.put("WORKING_DIR", "/" + workingDirName);
+        int sideContainersMemoryAllocations = hasS3Bucket ? baseSideContainerMemory * 2 : 0;
+        float sideContainersCpuAllocations = hasS3Bucket ? baseSideContainerCpu * 2 : 0;
 
         TaskContainerProperties.Builder mainContainerBuilder = withResources(
             TaskContainerProperties.builder()
-                .image(commands.getContainerImage())
-                .command(command)
+                .image(scriptCommands.getContainerImage())
+                .command(scriptCommands.getCommands())
                 .name(mainContainerName)
                 .logConfiguration(
                     LogConfiguration.builder()
@@ -331,47 +317,36 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List
                 .environment(
                     environment.entrySet().stream()
                         .map(e -> KeyValuePair.builder().name(e.getKey()).value(e.getValue()).build())
                         .toArray(KeyValuePair[]::new)
                 )
-                .essential(!hasFilesToDownload),
+                .essential(!hasS3Bucket),
             Integer.parseInt(resources.getRequest().getMemory()) - sideContainersMemoryAllocations,
             Float.parseFloat(resources.getRequest().getCpu()) - sideContainersCpuAllocations
         );
 
-        if (hasFilesToUpload) {
+        if (hasS3Bucket) {
             mainContainerBuilder.dependsOn(TaskContainerDependency.builder().containerName(inputFilesContainerName).condition("SUCCESS").build());
-        }
-
-        if (hasFilesToUpload || hasFilesToDownload) {
-            mainContainerBuilder.mountPoints(
-                MountPoint.builder()
-                    .containerPath("/" + workingDirName)
-                    .sourceVolume(kestraVolume)
-                    .build()
-            );
+            mainContainerBuilder.mountPoints(volumeMount);
         }
 
         containers.add(mainContainerBuilder.build());
 
-        if (hasFilesToDownload) {
+        if (hasS3Bucket) {
             containers.add(
                 withResources(
                     TaskContainerProperties.builder()
                         .image("ghcr.io/kestra-io/awsbatch:latest")
-                        .mountPoints(
-                            MountPoint.builder()
-                                .containerPath("/" + workingDirName)
-                                .sourceVolume(kestraVolume)
-                                .build()
-                        )
+                        .mountPoints(volumeMount)
                        .command(ScriptService.scriptCommands(
                             List.of("/bin/sh", "-c"),
                             null,
-                            filesToDownload.stream()
-                                .map(relativePath -> "aws s3 cp /" + workingDirName + "/" + relativePath + " " + s3WorkingDir + "/" + relativePath)
-                                .toList()
+                            Stream.concat(
+                                filesToDownload.stream()
+                                    .map(relativePath -> "aws s3 cp " + batchWorkingDirectory + "/" + relativePath + " " + s3WorkingDir + Path.of("/" + relativePath)),
+                                Stream.of("aws s3 cp " + batchOutputDirectory + "/ " + s3WorkingDir + "/" + batchWorkingDirectory.relativize(batchOutputDirectory) + "/ --recursive")
+                            ).toList()
                         ))
                         .dependsOn(TaskContainerDependency.builder().containerName(mainContainerName).condition("SUCCESS").build())
                         .name(outputFilesContainerName),
@@ -480,32 +455,57 @@ public RunnerResult run(RunContext runContext, ScriptCommands commands, List
                 filesToDownload.stream().map(relativePath -> transferManager.downloadFile(
                     DownloadFileRequest.builder()
                         .getObjectRequest(GetObjectRequest.builder()
                             .bucket(renderedBucket)
-                            .key(workingDirName + "/" + relativePath)
+                            .key((batchWorkingDirectory + "/" + relativePath).substring(1))
                             .build())
-                        .destination(runContext.resolve(Path.of(relativePath)))
+                        .destination(scriptCommands.getWorkingDirectory().resolve(Path.of(relativePath.startsWith("/") ? relativePath.substring(1) : relativePath)))
                         .build()
                 )).map(FileDownload::completionFuture)
                     .forEach(throwConsumer(CompletableFuture::get));
+
+                transferManager.downloadDirectory(DownloadDirectoryRequest.builder()
+                        .bucket(renderedBucket)
+                        .destination(scriptCommands.getOutputDirectory())
+                        .listObjectsV2RequestTransformer(builder -> builder
+                            .prefix(batchOutputDirectory.toString().substring(1))
+                        )
+                        .build())
+                    .completionFuture()
+                    .get();
            }
        }
    } finally {
         cleanupBatchResources(client, jobQueue, jobDefArn);
         // Manual close after cleanup to make sure we get all remaining logs
         cloudWatchLogsAsyncClient.close();
-        if (hasFilesToUpload || hasFilesToDownload) {
-            cleanupS3Resources(runContext, filesToUpload, filesToDownload, workingDirName);
+        if (hasS3Bucket) {
+            cleanupS3Resources(runContext, batchWorkingDirectory);
         }
     }
 
     return new RunnerResult(0, logConsumer);
 }
 
+    @Override
+    public Map<String, Object> runnerAdditionalVars(RunContext runContext, ScriptCommands scriptCommands) throws IllegalVariableEvaluationException {
+        Map<String, Object> additionalVars = new HashMap<>();
+        Path batchWorkingDirectory = Path.of("/" + IdUtils.create());
+        additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory);
+
+        if (this.bucket != null) {
+            Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create());
+            additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory);
+            additionalVars.put(ScriptService.VAR_BUCKET_PATH, "s3://" + runContext.render(this.bucket) + batchWorkingDirectory);
+        }
+
+        return additionalVars;
+    }
+
     @Nullable
     private static String getLogGroupArn(CloudWatchLogsAsyncClient cloudWatchLogsAsyncClient, String logGroupName, String renderedRegion) throws InterruptedException, ExecutionException {
         return cloudWatchLogsAsyncClient.describeLogGroups(
@@ -534,15 +534,21 @@ private TaskContainerProperties.Builder withResources(TaskContainerProperties.Bu
         );
     }
 
-    private void cleanupS3Resources(RunContext runContext, List<String> filesToUpload, List<String> filesToDownload, String workingDirName) throws IllegalVariableEvaluationException {
-        String renderedBucket = runContext.render(s3Bucket);
+    private void cleanupS3Resources(RunContext runContext, Path batchWorkingDirectory) throws IllegalVariableEvaluationException {
+        String renderedBucket = runContext.render(bucket);
         try (S3AsyncClient s3AsyncClient = asyncClient(runContext)) {
+            ListObjectsV2Request listRequest = ListObjectsV2Request.builder()
+                .bucket(renderedBucket)
+                .prefix(batchWorkingDirectory.toString())
+                .build();
+
+            ListObjectsV2Response listResponse = s3AsyncClient.listObjectsV2(listRequest).get();
+            List<ObjectIdentifier> objectsIdentifiers = listResponse.contents().stream()
+                .map(S3Object::key)
+                .map(key -> ObjectIdentifier.builder().key(key).build())
+                .toList();
             StreamSupport.stream(Iterables.partition(
-                Stream.concat(
-                    Optional.ofNullable(filesToUpload).stream().flatMap(Collection::stream),
-                    Optional.ofNullable(filesToDownload).stream().flatMap(Collection::stream)
-                ).map(file -> ObjectIdentifier.builder().key("/" + workingDirName + "/" + file).build())
-                    .toList(),
+                objectsIdentifiers,
                 1000
             ).spliterator(), false)
                 .map(objects -> s3AsyncClient.deleteObjects(
@@ -554,7 +560,7 @@ private void cleanupS3Resources(RunContext runContext, List filesToUploa
                         .build()
                 )).forEach(throwConsumer(CompletableFuture::get));
         } catch (Exception e) {
-            runContext.logger().warn("Error while cleaning up S3: {}", e.getMessage());
+            runContext.logger().warn("Error while cleaning up S3: {}", e.toString());
         }
     }
diff --git a/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java b/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java
index d56486a..2140636 100644
--- a/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java
+++ b/src/test/java/io/kestra/plugin/aws/runner/AwsBatchScriptRunnerTest.java
@@ -1,34 +1,16 @@
 package io.kestra.plugin.aws.runner;
 
-import io.kestra.core.models.script.*;
-import io.kestra.core.runners.RunContext;
-import io.kestra.core.runners.RunContextFactory;
-import io.kestra.plugin.scripts.exec.scripts.runners.CommandsWrapper;
+import io.kestra.core.models.script.AbstractScriptRunnerTest;
+import io.kestra.core.models.script.ScriptRunner;
 import io.micronaut.context.annotation.Value;
 import io.micronaut.test.extensions.junit5.annotation.MicronautTest;
-import jakarta.inject.Inject;
-import org.apache.commons.io.FileUtils;
 import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
 
-import java.io.File;
-import java.nio.charset.StandardCharsets;
-import java.nio.file.Path;
 import java.time.Duration;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.is;
 
 @MicronautTest
 @Disabled("Too costly to run on CI")
 public class AwsBatchScriptRunnerTest extends AbstractScriptRunnerTest {
-    @Inject
-    private RunContextFactory runContextFactory;
-
     @Value("${kestra.aws.batch.accessKeyId}")
     private String accessKeyId;
 
@@ -38,62 +20,13 @@ public class AwsBatchScriptRunnerTest extends AbstractScriptRunnerTest {
     @Value("${kestra.aws.batch.s3Bucket}")
     private String s3Bucket;
 
-    @Override
-    @Test
-    protected void inputAndOutputFiles() throws Exception {
-        RunContext runContext = runContextFactory.of(Map.of("internalStorageFile", "kestra://some/internalStorage.txt"));
-
-        // Generate input file
-        Path workingDirectory = runContext.tempDir();
-        File file = workingDirectory.resolve("hello.txt").toFile();
-        FileUtils.writeStringToFile(file, "Hello World", "UTF-8");
-
-        // Generate internal storage file
-        FileUtils.writeStringToFile(Path.of("/tmp/unittest/internalStorage.txt").toFile(), "Hello from internal storage", StandardCharsets.UTF_8);
-
-        DefaultLogConsumer defaultLogConsumer = new DefaultLogConsumer(runContext);
-        // This is purely to showcase that no logs is sent as STDERR for now as CloudWatch doesn't seem to send such information.
-        Map<String, Boolean> logsWithIsStdErr = new HashMap<>();
-        CommandsWrapper commandsWrapper = new CommandsWrapper(runContext)
-            .withCommands(ScriptService.scriptCommands(List.of("/bin/sh", "-c"), null, List.of(
-                "cat {{workingDir}}/{{internalStorageFile}}",
-                "cat {{workingDir}}/hello.txt",
-                "cat {{workingDir}}/hello.txt > {{workingDir}}/output.txt",
-                "echo '::{\"outputs\":{\"logOutput\":\"Hello World\"}}::'"
-            )))
-            .withContainerImage("ghcr.io/kestra-io/awsbatch:latest")
-            .withLogConsumer(new AbstractLogConsumer() {
-                @Override
-                public void accept(String log, Boolean isStdErr) {
-                    logsWithIsStdErr.put(log, isStdErr);
-                    defaultLogConsumer.accept(log, isStdErr);
-                }
-            });
-        RunnerResult run = scriptRunner().run(runContext, commandsWrapper, List.of("hello.txt"), List.of("output.txt"));
-
-        // Exit code for successful job
-        assertThat(run.getExitCode(), is(0));
-
-        // Verify logs, we can't assert exact log entries as logs are sometimes grouped together by AWS CloudWatch
-        Set<Map.Entry<String, Boolean>> logEntries = logsWithIsStdErr.entrySet();
-        assertThat(logEntries.stream().filter(e -> e.getKey().startsWith("[JOB LOG]")).findFirst().orElseThrow().getValue(), is(false));
-        assertThat(logEntries.stream().filter(e -> e.getKey().contains("Hello from internal storage")).findFirst().orElseThrow().getValue(), is(false));
-        assertThat(logEntries.stream().filter(e -> e.getKey().contains("Hello World")).findFirst().orElseThrow().getValue(), is(false));
-
-        // Verify outputFiles
-        File outputFile = runContext.resolve(Path.of("output.txt")).toFile();
-        assertThat(outputFile.exists(), is(true));
-        assertThat(FileUtils.readFileToString(outputFile, StandardCharsets.UTF_8), is("Hello World"));
-
-        assertThat(defaultLogConsumer.getOutputs().get("logOutput"), is("Hello World"));
-    }
 
     @Override
     protected ScriptRunner scriptRunner() {
         return AwsBatchScriptRunner.builder()
             .accessKeyId(accessKeyId)
             .secretKeyId(secretKeyId)
-            .s3Bucket(s3Bucket)
+            .bucket(s3Bucket)
             .region("eu-west-3")
             .computeEnvironmentArn("arn:aws:batch:eu-west-3:634784741179:compute-environment/FargateComputeEnvironment")
             .executionRoleArn("arn:aws:iam::634784741179:role/AWS-Batch-Role-For-Fargate")
@@ -102,4 +35,9 @@ protected ScriptRunner scriptRunner() {
             .jobQueueArn("arn:aws:batch:eu-west-3:634784741179:job-queue/queue")
             .build();
     }
+
+    @Override
+    protected boolean needsToSpecifyWorkingDirectory() {
+        return true;
+    }
 }
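
Usage sketch (editorial note, not part of the patch): with this change the S3 bucket is set through the renamed bucket property, and the runner derives the remote working directory, output directory, and S3 bucket path itself via runnerAdditionalVars instead of the old S3_WORKING_DIR_KEY/WORKING_DIR_KEY constants. A minimal Java configuration under those assumptions, mirroring the builder fields used in the test above; the credentials, region, and ARNs are placeholders:

    ScriptRunner runner = AwsBatchScriptRunner.builder()
        .accessKeyId("<access-key-id>")    // placeholder credentials
        .secretKeyId("<secret-key-id>")
        .region("eu-west-3")
        .bucket("my-bucket")               // renamed from s3Bucket; required for input/output file transfer
        .computeEnvironmentArn("arn:aws:batch:eu-west-3:123456789012:compute-environment/my-env")
        .executionRoleArn("arn:aws:iam::123456789012:role/my-batch-execution-role")
        .jobQueueArn("arn:aws:batch:eu-west-3:123456789012:job-queue/my-queue")
        .waitUntilCompletion(Duration.ofMinutes(30))   // defaults to 1 hour
        .build();

When a bucket is configured, the patched runner schedules three containers: an input side container that copies input files from the S3 working-directory prefix and creates the output directory, the main container, and an output side container that copies declared output files plus the whole output directory back to S3 (the recursive "aws s3 cp ... --recursive" added above); the working-directory prefix is then deleted from the bucket during cleanup.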