diff --git a/backend/src/main/scala/cromwell/backend/backend.scala b/backend/src/main/scala/cromwell/backend/backend.scala index ea413c10367..18718f0e76f 100644 --- a/backend/src/main/scala/cromwell/backend/backend.scala +++ b/backend/src/main/scala/cromwell/backend/backend.scala @@ -139,6 +139,8 @@ object CommonBackendConfigurationAttributes { "default-runtime-attributes.noAddress", "default-runtime-attributes.docker", "default-runtime-attributes.queueArn", + "default-runtime-attributes.awsBatchRetryAttempts", + "default-runtime-attributes.ulimits", "default-runtime-attributes.failOnStderr", "slow-job-warning-time", "dockerhub", diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index c236f934661..8c2c3b72c79 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -418,6 +418,8 @@ docker { dockerhub.num-threads = 10 quay.num-threads = 10 alibabacloudcr.num-threads = 10 + ecr.num-threads = 10 + ecr-public.num-threads = 10 } } diff --git a/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala b/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala index 825fc8992ab..7454afe4800 100644 --- a/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala +++ b/dockerHashing/src/main/scala/cromwell/docker/DockerInfoActor.scala @@ -14,6 +14,7 @@ import cromwell.core.actor.StreamIntegration.{BackPressure, StreamContext} import cromwell.core.{Dispatcher, DockerConfiguration} import cromwell.docker.DockerInfoActor._ import cromwell.docker.registryv2.DockerRegistryV2Abstract +import cromwell.docker.registryv2.flows.aws.{AmazonEcr, AmazonEcrPublic} import cromwell.docker.registryv2.flows.alibabacloudcrregistry._ import cromwell.docker.registryv2.flows.dockerhub.DockerHubRegistry import cromwell.docker.registryv2.flows.google.GoogleRegistry @@ -236,7 +237,9 @@ object DockerInfoActor { ("dockerhub", { c: DockerRegistryConfig => new DockerHubRegistry(c) }), ("google", { c: DockerRegistryConfig => new GoogleRegistry(c) }), ("quay", { c: DockerRegistryConfig => new QuayRegistry(c) }), - ("alibabacloudcr", {c: DockerRegistryConfig => new AlibabaCloudCRRegistry(c)}) + ("alibabacloudcr", {c: DockerRegistryConfig => new AlibabaCloudCRRegistry(c)}), + ("ecr", {c: DockerRegistryConfig => new AmazonEcr(c)}), + ("ecr-public", {c: DockerRegistryConfig => new AmazonEcrPublic(c)}) ).traverse[ErrorOr, DockerRegistry]({ case (configPath, constructor) => DockerRegistryConfig.fromConfig(config.as[Config](configPath)).map(constructor) }).unsafe("Docker registry configuration") diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala index e03b3bee3c5..b4ae630751b 100644 --- a/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/DockerRegistryV2Abstract.scala @@ -60,7 +60,7 @@ object DockerRegistryV2Abstract { ) }) } - + // Placeholder exceptions that can be carried through IO before being converted to a DockerInfoFailedResponse private class Unauthorized() extends Exception private class NotFound() extends Exception @@ -76,6 +76,8 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi implicit val cs = IO.contextShift(ec) implicit val timer = IO.timer(ec) + protected val authorizationScheme: AuthScheme = AuthScheme.Bearer + /** * This is the main 
function. Given a docker context and an http client, retrieve information about the docker image. */ @@ -204,7 +206,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi * Request to get the manifest, using the auth token if provided */ private def manifestRequest(token: Option[String], imageId: DockerImageIdentifier): IO[Request[IO]] = { - val authorizationHeader = token.map(t => Authorization(Credentials.Token(AuthScheme.Bearer, t))) + val authorizationHeader = token.map(t => Authorization(Credentials.Token(authorizationScheme, t))) val request = Method.GET( buildManifestUri(imageId), List( @@ -235,13 +237,13 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi * The response can be of 2 sorts: * - A manifest (https://docs.docker.com/registry/spec/manifest-v2-2/#image-manifest-field-descriptions) * - A manifest list which contains a list of pointers to other manifests (https://docs.docker.com/registry/spec/manifest-v2-2/#manifest-list) - * + * * When a manifest list is returned, we need to pick one of the manifest pointers and make another request for that manifest. - * + * * Because the different manifests in the list are (supposed to be) variations of the same image over different platforms, * we simply pick the first one here since we only care about the approximate size, and we don't expect it to change drastically * between platforms. - * If that assumption turns out to be incorrect, a smarter decision may need to be made to choose the manifest to lookup. + * If that assumption turns out to be incorrect, a smarter decision may need to be made to choose the manifest to lookup. */ private def parseManifest(dockerImageIdentifier: DockerImageIdentifier, token: Option[String])(response: Response[IO])(implicit client: Client[IO]): IO[Option[DockerManifest]] = response match { case Status.Successful(r) if r.headers.exists(_.value.equalsIgnoreCase(ManifestV2MediaType)) => @@ -268,14 +270,14 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi } } - private def getDigestFromResponse(response: Response[IO]): IO[DockerHashResult] = response match { + protected def getDigestFromResponse(response: Response[IO]): IO[DockerHashResult] = response match { case Status.Successful(r) => extractDigestFromHeaders(r.headers) case Status.Unauthorized(_) => IO.raiseError(new Unauthorized) case Status.NotFound(_) => IO.raiseError(new NotFound) case failed => failed.as[String].flatMap(body => IO.raiseError(new Exception(s"Failed to get manifest: $body")) ) } - + private def extractDigestFromHeaders(headers: Headers) = { headers.find(a => a.toRaw.name.equals(DigestHeaderName)) match { case Some(digest) => IO.fromEither(DockerHashResult.fromString(digest.value).toEither) diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcr.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcr.scala new file mode 100644 index 00000000000..a073b9f4eb2 --- /dev/null +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcr.scala @@ -0,0 +1,42 @@ +package cromwell.docker.registryv2.flows.aws + +import cats.effect.IO +import cromwell.docker.{DockerImageIdentifier, DockerInfoActor, DockerRegistryConfig} +import org.http4s.AuthScheme +import org.http4s.client.Client +import software.amazon.awssdk.services.ecr.EcrClient + +import scala.compat.java8.OptionConverters._ +import scala.concurrent.Future + +class AmazonEcr(override val config: 
DockerRegistryConfig, ecrClient: EcrClient = EcrClient.create()) extends AmazonEcrAbstract(config) { + + override protected val authorizationScheme: AuthScheme = AuthScheme.Basic + + /** + * e.g. 123456789012.dkr.ecr.us-east-1.amazonaws.com + */ + override protected def registryHostName(dockerImageIdentifier: DockerImageIdentifier): String = { + var hostname = dockerImageIdentifier.hostAsString + if (hostname.lastIndexOf("/").equals(hostname.length -1)) { + hostname = hostname.substring(0, hostname.length -1) + } + hostname + } + /** + * Returns true if this flow is able to process this docker image, + * false otherwise + */ + override def accepts(dockerImageIdentifier: DockerImageIdentifier): Boolean = dockerImageIdentifier.hostAsString.contains("amazonaws.com") + + override protected def getToken(dockerInfoContext: DockerInfoActor.DockerInfoContext)(implicit client: Client[IO]): IO[Option[String]] = { + val eventualMaybeToken = Future(ecrClient.getAuthorizationToken + .authorizationData() + .stream() + .findFirst() + .asScala + .map(_.authorizationToken())) + + IO.fromFuture(IO(eventualMaybeToken)) + } +} diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrAbstract.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrAbstract.scala new file mode 100644 index 00000000000..bc9ea61fd74 --- /dev/null +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrAbstract.scala @@ -0,0 +1,41 @@ +package cromwell.docker.registryv2.flows.aws + +import cats.effect.IO +import cromwell.docker.{DockerHashResult, DockerImageIdentifier, DockerInfoActor, DockerRegistryConfig} +import cromwell.docker.registryv2.DockerRegistryV2Abstract +import cromwell.docker.registryv2.flows.aws.EcrUtils.{EcrForbidden, EcrNotFound, EcrUnauthorized} +import org.apache.commons.codec.digest.DigestUtils +import org.http4s.{Header, Response, Status} + +abstract class AmazonEcrAbstract(override val config: DockerRegistryConfig) extends DockerRegistryV2Abstract(config) { + + /** + * Not used as getToken is overridden + */ + override protected def authorizationServerHostName(dockerImageIdentifier: DockerImageIdentifier): String = "" + + /** + * Not used as getToken is overridden + */ + override protected def buildTokenRequestHeaders(dockerInfoContext: DockerInfoActor.DockerInfoContext): List[Header] = List.empty + + /** + * Amazon ECR repositories don't have a digest header in responses, so we must compute the digest from the manifest body + */ + override protected def getDigestFromResponse(response: Response[IO]): IO[DockerHashResult] = response match { + case Status.Successful(r) => digestManifest(r.bodyText) + case Status.Unauthorized(_) => IO.raiseError(new EcrUnauthorized) + case Status.NotFound(_) => IO.raiseError(new EcrNotFound) + case Status.Forbidden(_) => IO.raiseError(new EcrForbidden) + case failed => failed.as[String].flatMap(body => IO.raiseError(new Exception(s"Failed to get manifest: $body"))) + } + + private def digestManifest(bodyText: fs2.Stream[IO, String]): IO[DockerHashResult] = { + bodyText + .compile + .string + .map(data => "sha256:"+DigestUtils.sha256Hex(data)) + .map(DockerHashResult.fromString) + .flatMap(IO.fromTry) + } +} diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublic.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublic.scala new file mode 100644 index 00000000000..797d4b70703 --- /dev/null +++
b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublic.scala @@ -0,0 +1,36 @@ +package cromwell.docker.registryv2.flows.aws + +import cats.effect.IO +import cromwell.docker.{DockerImageIdentifier, DockerInfoActor, DockerRegistryConfig} +import org.http4s.client.Client +import software.amazon.awssdk.services.ecrpublic.EcrPublicClient +import software.amazon.awssdk.services.ecrpublic.model.GetAuthorizationTokenRequest + +import scala.concurrent.Future + + +class AmazonEcrPublic(override val config: DockerRegistryConfig, ecrClient: EcrPublicClient = EcrPublicClient.create()) extends AmazonEcrAbstract(config) { + /** + * public.ecr.aws + */ + override protected def registryHostName(dockerImageIdentifier: DockerImageIdentifier): String = "public.ecr.aws" + + /** + * Returns true if this flow is able to process this docker image, + * false otherwise + */ + override def accepts(dockerImageIdentifier: DockerImageIdentifier): Boolean = dockerImageIdentifier.hostAsString.contains("public.ecr.aws") + + + override protected def getToken(dockerInfoContext: DockerInfoActor.DockerInfoContext)(implicit client: Client[IO]): IO[Option[String]] = { + + val eventualMaybeToken: Future[Option[String]] = Future( + Option(ecrClient + .getAuthorizationToken(GetAuthorizationTokenRequest.builder().build()) + .authorizationData.authorizationToken() + ) + ) + + IO.fromFuture(IO(eventualMaybeToken)) + } +} diff --git a/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/EcrUtils.scala b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/EcrUtils.scala new file mode 100644 index 00000000000..9da904c8adc --- /dev/null +++ b/dockerHashing/src/main/scala/cromwell/docker/registryv2/flows/aws/EcrUtils.scala @@ -0,0 +1,9 @@ +package cromwell.docker.registryv2.flows.aws + +object EcrUtils { + + case class EcrUnauthorized() extends Exception + case class EcrNotFound() extends Exception + case class EcrForbidden() extends Exception + +} diff --git a/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublicSpec.scala b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublicSpec.scala new file mode 100644 index 00000000000..18a0a23232e --- /dev/null +++ b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrPublicSpec.scala @@ -0,0 +1,69 @@ +package cromwell.docker.registryv2.flows.aws + +import cats.effect.{IO, Resource} +import cromwell.core.TestKitSuite +import cromwell.docker.registryv2.DockerRegistryV2Abstract +import cromwell.docker.{DockerImageIdentifier, DockerInfoActor, DockerInfoRequest, DockerRegistryConfig} +import org.http4s.{Header, Headers, MediaType, Request, Response} +import org.http4s.client.Client +import org.http4s.headers.`Content-Type` +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.flatspec.AnyFlatSpecLike +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar +import software.amazon.awssdk.services.ecrpublic.model.{AuthorizationData, GetAuthorizationTokenRequest, GetAuthorizationTokenResponse} +import software.amazon.awssdk.services.ecrpublic.EcrPublicClient + +class AmazonEcrPublicSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with MockitoSugar with BeforeAndAfter with PrivateMethodTester { + behavior of "AmazonEcrPublic" + + val goodUri = "public.ecr.aws/amazonlinux/amazonlinux:latest" + val otherUri = "ubuntu:latest" + + 
+ val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.ManifestV2MediaType).right.get + val contentType: Header = `Content-Type`(mediaType) + val mockEcrClient: EcrPublicClient = mock[EcrPublicClient] + implicit val mockIOClient: Client[IO] = Client({ _: Request[IO] => + // This response will have an empty body, so we need to be explicit about the typing: + Resource.pure[IO, Response[IO]](Response(headers = Headers.of(contentType))) : Resource[IO, Response[IO]] + }) + + val registry = new AmazonEcrPublic(DockerRegistryConfig.default, mockEcrClient) + + it should "Accept good URI" in { + val dockerImageIdentifier = DockerImageIdentifier.fromString(goodUri).get + registry.accepts(dockerImageIdentifier) shouldEqual true + } + + it should "NOT accept other URI" in { + val dockerImageIdentifier = DockerImageIdentifier.fromString(otherUri).get + registry.accepts(dockerImageIdentifier) shouldEqual false + } + + it should "have public.ecr.aws as registryHostName" in { + val registryHostNameMethod = PrivateMethod[String]('registryHostName) + registry invokePrivate registryHostNameMethod(DockerImageIdentifier.fromString(goodUri).get) shouldEqual "public.ecr.aws" + } + + it should "return expected auth token" in { + val token = "auth-token" + val imageId = DockerImageIdentifier.fromString(goodUri).get + val dockerInfoRequest = DockerInfoRequest(imageId) + val context = DockerInfoActor.DockerInfoContext(request = dockerInfoRequest, replyTo = emptyActor) + + when(mockEcrClient.getAuthorizationToken(any[GetAuthorizationTokenRequest]())) + .thenReturn(GetAuthorizationTokenResponse + .builder() + .authorizationData(AuthorizationData + .builder() + .authorizationToken(token) + .build()) + .build) + + val getTokenMethod = PrivateMethod[IO[Option[String]]]('getToken) + registry invokePrivate getTokenMethod(context, mockIOClient) ensuring(io => io.unsafeRunSync().get == token) + } +} diff --git a/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrSpec.scala b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrSpec.scala new file mode 100644 index 00000000000..5ddf98c7ffa --- /dev/null +++ b/dockerHashing/src/test/scala/cromwell/docker/registryv2/flows/aws/AmazonEcrSpec.scala @@ -0,0 +1,72 @@ +package cromwell.docker.registryv2.flows.aws + +import cats.effect.{IO, Resource} +import cromwell.core.TestKitSuite +import cromwell.docker.registryv2.DockerRegistryV2Abstract +import cromwell.docker.{DockerImageIdentifier, DockerInfoActor, DockerInfoRequest, DockerRegistryConfig} +import org.http4s.{AuthScheme, Header, Headers, MediaType, Request, Response} +import org.http4s.client.Client +import org.http4s.headers.`Content-Type` +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.flatspec.AnyFlatSpecLike +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar +import software.amazon.awssdk.services.ecr.EcrClient +import software.amazon.awssdk.services.ecr.model.{AuthorizationData, GetAuthorizationTokenResponse} + +class AmazonEcrSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with MockitoSugar with BeforeAndAfter with PrivateMethodTester{ + behavior of "AmazonEcr" + + val goodUri = "123456789012.dkr.ecr.us-east-1.amazonaws.com/amazonlinux/amazonlinux:latest" + val otherUri = "ubuntu:latest" + + val mediaType: MediaType = MediaType.parse(DockerRegistryV2Abstract.ManifestV2MediaType).right.get + val contentType: Header = `Content-Type`(mediaType) + 
val mockEcrClient: EcrClient = mock[EcrClient] + implicit val mockIOClient: Client[IO] = Client({ _: Request[IO] => + // This response will have an empty body, so we need to be explicit about the typing: + Resource.pure[IO, Response[IO]](Response(headers = Headers.of(contentType))) : Resource[IO, Response[IO]] + }) + + val registry = new AmazonEcr(DockerRegistryConfig.default, mockEcrClient) + + it should "accept good URI" in { + val dockerImageIdentifier = DockerImageIdentifier.fromString(goodUri).get + registry.accepts(dockerImageIdentifier) shouldEqual true + } + + it should "NOT accept other URI" in { + val dockerImageIdentifier = DockerImageIdentifier.fromString(otherUri).get + registry.accepts(dockerImageIdentifier) shouldEqual false + } + + it should "use Basic Auth Scheme" in { + val authSchemeMethod = PrivateMethod[AuthScheme]('authorizationScheme) + registry invokePrivate authSchemeMethod() shouldEqual AuthScheme.Basic + } + + it should "return 123456789012.dkr.ecr.us-east-1.amazonaws.com as registryHostName" in { + val registryHostNameMethod = PrivateMethod[String]('registryHostName) + registry invokePrivate registryHostNameMethod(DockerImageIdentifier.fromString(goodUri).get) shouldEqual "123456789012.dkr.ecr.us-east-1.amazonaws.com" + } + + it should "return expected auth token" in { + val token = "auth-token" + val imageId = DockerImageIdentifier.fromString(goodUri).get + val dockerInfoRequest = DockerInfoRequest(imageId) + val context = DockerInfoActor.DockerInfoContext(request = dockerInfoRequest, replyTo = emptyActor) + + when(mockEcrClient.getAuthorizationToken) + .thenReturn(GetAuthorizationTokenResponse + .builder() + .authorizationData(AuthorizationData + .builder() + .authorizationToken(token) + .build()) + .build) + + val getTokenMethod = PrivateMethod[IO[Option[String]]]('getToken) + registry invokePrivate getTokenMethod(context, mockIOClient) ensuring(io => io.unsafeRunSync().get == token) + } +} diff --git a/docs/LanguageSupport.md b/docs/LanguageSupport.md index 74f5ab238f6..d14d6751a42 100644 --- a/docs/LanguageSupport.md +++ b/docs/LanguageSupport.md @@ -26,7 +26,7 @@ As well as the changes to the WDL spec between draft-2 and 1.0, Cromwell also su ### CWL 1.0 Cromwell provides support for Common Workflow Language (CWL), beginning with the core spec, and most heavily used requirements. -If you spot a CWL feature that Cromwell doesn't support, please notify us using an issue on our github page! +If you spot a CWL feature that Cromwell doesn't support, please notify us using an issue on our [Jira page](https://broadworkbench.atlassian.net/secure/RapidBoard.jspa?rapidView=39&view=planning.nodetail&issueLimit=100)! ## Future Language Support diff --git a/docs/RuntimeAttributes.md b/docs/RuntimeAttributes.md index 70485339b3e..c31b6fbbf41 100644 --- a/docs/RuntimeAttributes.md +++ b/docs/RuntimeAttributes.md @@ -56,6 +56,9 @@ There are a number of additional runtime attributes that apply to the Google Clo - [useDockerImageCache](#usedockerimagecache) +### AWS Specific Attributes +- [awsBatchRetryAttempts](#awsBatchRetryAttempts) +- [ulimits](#ulimits) ## Expression support @@ -323,8 +326,6 @@ runtime { ``` - - ### `bootDiskSizeGb` In addition to working disks, Google Cloud allows specification of a boot disk size. This is the disk where the docker image itself is booted (**not the working directory of your task on the VM**). 
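(Purely illustrative, not part of the patch.) A task that ships a large docker image might request a bigger boot disk like this; the image name and the 50 GB figure are placeholders, and the attribute is assumed to take an integer number of GB, as its name suggests:

```
runtime {
  # This docker image is large, so ask for a 50 GB boot disk to hold it
  docker: "my-org/my-large-image:latest"
  bootDiskSizeGb: 50
}
```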
@@ -373,6 +374,54 @@ runtime { } ``` + +### `awsBatchRetryAttempts` + +*Default: _0_* + +This runtime attribute adds support for [*AWS Batch Automated Job Retries*](https://docs.aws.amazon.com/batch/latest/userguide/job_retries.html), which makes it possible to tackle transient job failures. For example, if a task fails due to a timeout from accessing an external service, then this option helps re-run the failed task without having to re-run the entire workflow. It takes an integer between 1 and 10 indicating the maximum number of times AWS Batch should retry a failed task. If the value 0 is passed, the [*Retry Strategy*](https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#retryStrategy) will not be added to the job definition and the task will run just once. + +``` +runtime { + awsBatchRetryAttempts: integer +} +``` + + +### `ulimits` + +*Default: _empty_* + +A list of [`ulimits`](https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#containerProperties) values to set in the container. This parameter maps to `Ulimits` in the [Create a container](https://docs.docker.com/engine/api/v1.38/) section of the [Docker Remote API](https://docs.docker.com/engine/api/v1.38/) and the `--ulimit` option to [docker run](https://docs.docker.com/engine/reference/commandline/run/). + +``` +"ulimits": [ + { + "name": string, + "softLimit": integer, + "hardLimit": integer + } + ... +] +``` +Parameter description: + +- `name` + - The `type` of the `ulimit`. + - Type: String + - Required: Yes, when `ulimits` is used. + +- `softLimit` + - The soft limit for the `ulimit` type. + - Type: Integer + - Required: Yes, when `ulimits` is used. + +- `hardLimit` + - The hard limit for the `ulimit` type. + - Type: Integer + - Required: Yes, when `ulimits` is used. + + #### How to Setup Configure your Google network to use "Private Google Access". This will allow your VMs to access Google Services including Google Container Registry, as well as Dockerhub images. diff --git a/docs/backends/Google.md b/docs/backends/Google.md index cf694125de6..907b90c10f6 100644 --- a/docs/backends/Google.md +++ b/docs/backends/Google.md @@ -11,7 +11,7 @@ The instructions below assume you have created a Google Cloud Storage bucket and **Configuring Authentication** -The `google` stanza in the Cromwell configuration file defines how to authenticate to Google. There are four different +The `google` stanza in the Cromwell configuration file defines how to authenticate to Google. There are five different authentication schemes that might be used: * `application_default` (default, recommended) - Use [application default](https://developers.google.com/identity/protocols/application-default-credentials) credentials. @@ -48,7 +48,7 @@ the `genomics` and `filesystems.gcs` sections within a Google configuration bloc The auth for the `genomics` section governs the interactions with Google itself, while `filesystems.gcs` governs the localization of data into and out of GCE VMs. -**Application Default Credentials** +***Application Default Credentials*** By default, application default credentials will be used. Only `name` and `scheme` are required for application default credentials. @@ -59,7 +59,7 @@ $ gcloud auth login $ gcloud config set project my-project ``` -**Service Account** +***Service Account*** First create a new service account through the [API Credentials](https://console.developers.google.com/apis/credentials) page. Go to **Create credentials -> Service account key**.
Then in the **Service account** dropdown select **New service account**. Fill in a name (e.g. `my-account`), and select key type of JSON. diff --git a/docs/tutorials/MetadataEndpoint.md b/docs/tutorials/MetadataEndpoint.md index 2883c1e0528..fd463f270f4 100644 --- a/docs/tutorials/MetadataEndpoint.md +++ b/docs/tutorials/MetadataEndpoint.md @@ -25,4 +25,4 @@ After completing this tutorial you might find the following page interesting: _Drop us a line in the [Forum](https://gatkforums.broadinstitute.org/wdl/categories/ask-the-wdl-team) if you have a question._ \*\*\* **UNDER CONSTRUCTION** \*\*\* -[![Pennywell pig in red wellies - Richard Austin Images](http://www.richardaustinimages.com/wp-content/uploads/2015/04/fluffyAustin_Pigets_Wellies-500x395.jpg)](http://www.richardaustinimages.com/product/pennywell-pigs-under-umbrella-2/) +[![Pennywell pig in red wellies - Richard Austin Images](https://static.wixstatic.com/media/b0d56a_6b627e45766d44fa8b2714f5d7860c84~mv2.jpg/v1/fill/w_787,h_551,al_c,q_50,usm_0.66_1.00_0.01/b0d56a_6b627e45766d44fa8b2714f5d7860c84~mv2.jpg) diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/AmazonS3Factory.java b/filesystems/s3/src/main/java/org/lerch/s3fs/AmazonS3Factory.java index 987e8d71c6a..9d525587642 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/AmazonS3Factory.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/AmazonS3Factory.java @@ -1,20 +1,16 @@ package org.lerch.s3fs; - + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.*; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3ClientBuilder; import software.amazon.awssdk.services.s3.S3Configuration; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import java.net.URI; import java.util.Properties; diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3AccessControlList.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3AccessControlList.java index 0d908001ed3..1f219164707 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3AccessControlList.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3AccessControlList.java @@ -1,14 +1,14 @@ package org.lerch.s3fs; -import static java.lang.String.format; +import software.amazon.awssdk.services.s3.model.Grant; +import software.amazon.awssdk.services.s3.model.Owner; +import software.amazon.awssdk.services.s3.model.Permission; import java.nio.file.AccessDeniedException; import java.nio.file.AccessMode; import java.util.EnumSet; -import software.amazon.awssdk.services.s3.model.Grant; -import software.amazon.awssdk.services.s3.model.Owner; -import software.amazon.awssdk.services.s3.model.Permission; +import static java.lang.String.format; 
public class S3AccessControlList { private String fileStoreName; diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileChannel.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileChannel.java index 2f8bd41ff8c..e849294791d 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileChannel.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileChannel.java @@ -1,8 +1,15 @@ package org.lerch.s3fs; import org.apache.tika.Tika; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import java.io.*; +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; @@ -14,13 +21,6 @@ import java.util.HashSet; import java.util.Set; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.core.ResponseInputStream; -import software.amazon.awssdk.services.s3.model.GetObjectRequest; -import software.amazon.awssdk.services.s3.model.GetObjectResponse; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.S3Object; - import static java.lang.String.format; public class S3FileChannel extends FileChannel implements S3Channel { @@ -46,7 +46,7 @@ else if (!exists && !this.options.contains(StandardOpenOption.CREATE_NEW) && boolean removeTempFile = true; try { if (exists) { - try (ResponseInputStream byteStream = path.getFileSystem() + try (ResponseInputStream byteStream = path.getFileStore() .getClient() .getObject(GetObjectRequest .builder() @@ -171,7 +171,7 @@ protected void sync() throws IOException { .contentLength(length) .contentType(new Tika().detect(stream, path.getFileName().toString())); - path.getFileSystem().getClient().putObject(builder.build(), RequestBody.fromInputStream(stream, length)); + path.getFileStore().getClient().putObject(builder.build(), RequestBody.fromInputStream(stream, length)); } } } diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java index d8082d0605b..8e15999180a 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileStore.java @@ -1,28 +1,37 @@ package org.lerch.s3fs; +import org.lerch.s3fs.util.S3ClientStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.model.*; + import java.io.IOException; +import java.net.URI; import java.nio.file.FileStore; import java.nio.file.attribute.FileAttributeView; import java.nio.file.attribute.FileStoreAttributeView; import java.util.Date; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.Bucket; -import software.amazon.awssdk.services.s3.model.GetBucketAclRequest; -import software.amazon.awssdk.services.s3.model.HeadBucketRequest; -import 
software.amazon.awssdk.services.s3.model.ListBucketsRequest; -import software.amazon.awssdk.services.s3.model.NoSuchBucketException; -import software.amazon.awssdk.services.s3.model.Owner; -import com.google.common.collect.ImmutableList; - +/** + * In S3 a filestore translates to a bucket + */ public class S3FileStore extends FileStore implements Comparable { private S3FileSystem fileSystem; private String name; + private S3Client defaultClient; + private final Logger logger = LoggerFactory.getLogger("S3FileStore"); public S3FileStore(S3FileSystem s3FileSystem, String name) { this.fileSystem = s3FileSystem; this.name = name; + // the default client can be used for getBucketLocation operations + this.defaultClient = S3Client.builder().endpointOverride(URI.create("https://s3.us-east-1.amazonaws.com")).region(Region.US_EAST_1).build(); } @Override @@ -111,8 +120,8 @@ private boolean hasBucket(String bucketName) { // model as HeadBucket is now required boolean bucket = false; try { - getClient().headBucket(HeadBucketRequest.builder().bucket(bucketName).build()); - bucket = true; + getClient().headBucket(HeadBucketRequest.builder().bucket(bucketName).build()); + bucket = true; }catch(NoSuchBucketException nsbe) {} return bucket; } @@ -121,8 +130,13 @@ public S3Path getRootDirectory() { return new S3Path(fileSystem, "/" + this.name()); } - private S3Client getClient() { - return fileSystem.getClient(); + /** + * Gets a client suitable for this FileStore (bucket) including configuring the correct region endpoint. If no client + * exists one will be constructed and cached. + * @return a client + */ + public S3Client getClient() { + return S3ClientStore.getInstance().getClientForBucketName(this.name); } public Owner getOwner() { diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystem.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystem.java index 4fadae9518b..0038a787c64 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystem.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystem.java @@ -1,20 +1,16 @@ package org.lerch.s3fs; -import static org.lerch.s3fs.S3Path.PATH_SEPARATOR; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Bucket; import java.io.IOException; -import java.nio.file.FileStore; -import java.nio.file.FileSystem; -import java.nio.file.Path; -import java.nio.file.PathMatcher; -import java.nio.file.WatchService; +import java.nio.file.*; import java.nio.file.attribute.UserPrincipalLookupService; import java.util.Set; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.Bucket; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; +import static org.lerch.s3fs.S3Path.PATH_SEPARATOR; /** * S3FileSystem with a concrete client configured and ready to use. @@ -63,7 +59,7 @@ public boolean isReadOnly() { @Override public String getSeparator() { - return S3Path.PATH_SEPARATOR; + return PATH_SEPARATOR; } @Override @@ -113,9 +109,15 @@ public WatchService newWatchService() throws IOException { throw new UnsupportedOperationException(); } - public S3Client getClient() { - return client; - } +// /** +// * Deprecated: since SDKv2 many S3 operations need to be signed with a client using the same Region as the location +// * of the bucket. Prefer S3Path.client() instead. 
+// * @return +// */ +// @Deprecated +// public S3Client getClient() { +// return client; +// } /** * get the endpoint associated with this fileSystem. diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystemProvider.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystemProvider.java index 1da53532d2b..208db86d1ba 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystemProvider.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3FileSystemProvider.java @@ -1,11 +1,11 @@ package org.lerch.s3fs; -import org.apache.commons.lang3.tuple.ImmutablePair; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; +import org.apache.commons.lang3.tuple.ImmutablePair; import org.lerch.s3fs.attribute.S3BasicFileAttributeView; import org.lerch.s3fs.attribute.S3BasicFileAttributes; import org.lerch.s3fs.attribute.S3PosixFileAttributeView; @@ -37,7 +37,8 @@ import static com.google.common.collect.Sets.difference; import static java.lang.String.format; -import static java.lang.Thread.*; +import static java.lang.Thread.currentThread; +import static java.lang.Thread.sleep; import static org.lerch.s3fs.AmazonS3Factory.*; /** @@ -287,7 +288,10 @@ public S3FileSystem getFileSystem(URI uri) { if (fileSystems.containsKey(key)) { return fileSystems.get(key); } else { - throw new FileSystemNotFoundException("S3 filesystem not yet created. Use newFileSystem() instead"); + final String scheme = uri.getScheme(); + final String uriString = uri.toString(); + uriString.replace(scheme, "https://"); + return (S3FileSystem) newFileSystem(uri, Collections.emptyMap()); } } @@ -338,7 +342,7 @@ public InputStream newInputStream(Path path, OpenOption... options) throws IOExc try { ResponseInputStream res = s3Path - .getFileSystem() + .getFileStore() .getClient() .getObject(GetObjectRequest .builder() @@ -385,7 +389,7 @@ public void createDirectory(Path dir, FileAttribute... attrs) throws IOExcept Bucket bucket = s3Path.getFileStore().getBucket(); String bucketName = s3Path.getFileStore().name(); if (bucket == null) { - s3Path.getFileSystem().getClient().createBucket(CreateBucketRequest.builder().bucket(bucketName).build()); + s3Path.getFileStore().getClient().createBucket(CreateBucketRequest.builder().bucket(bucketName).build()); } // create the object as directory PutObjectRequest.Builder builder = PutObjectRequest.builder(); @@ -393,7 +397,7 @@ public void createDirectory(Path dir, FileAttribute... 
attrs) throws IOExcept builder.bucket(bucketName) .key(directoryKey) .contentLength(0L); - s3Path.getFileSystem().getClient().putObject(builder.build(), RequestBody.fromBytes(new byte[0])); + s3Path.getFileStore().getClient().putObject(builder.build(), RequestBody.fromBytes(new byte[0])); } @Override @@ -406,9 +410,9 @@ public void delete(Path path) throws IOException { String key = s3Path.getKey(); String bucketName = s3Path.getFileStore().name(); - s3Path.getFileSystem().getClient().deleteObject(DeleteObjectRequest.builder().bucket(bucketName).key(key).build()); + s3Path.getFileStore().getClient().deleteObject(DeleteObjectRequest.builder().bucket(bucketName).key(key).build()); // we delete the two objects (sometimes exists the key '/' and sometimes not) - s3Path.getFileSystem().getClient().deleteObject(DeleteObjectRequest.builder().bucket(bucketName).key(key + "/").build()); + s3Path.getFileStore().getClient().deleteObject(DeleteObjectRequest.builder().bucket(bucketName).key(key + "/").build()); } @Override @@ -438,7 +442,8 @@ public void copy(Path source, Path target, CopyOption... options) throws IOExcep String keySource = s3Source.getKey(); String bucketNameTarget = s3Target.getFileStore().name(); String keyTarget = s3Target.getKey(); - s3Source.getFileSystem() + // for a cross region copy the client must be for the target (region) not the source region + s3Target.getFileStore() .getClient() .copyObject(CopyObjectRequest.builder() .sourceBucket(bucketNameOrigin) @@ -460,7 +465,7 @@ public void copy(Path source, Path target, CopyOption... options) throws IOExcep private void multiPartCopy(S3Path source, long objectSize, S3Path target, CopyOption... options) { log.info(() -> "Attempting multipart copy as part of call cache hit: source = " + source + ", objectSize = " + objectSize + ", target = " + target + ", options = " + Arrays.deepToString(options)); - S3Client s3Client = target.getFileSystem().getClient(); + S3Client s3Client = target.getFileStore().getClient(); final CreateMultipartUploadRequest createMultipartUploadRequest = CreateMultipartUploadRequest.builder() .bucket(target.getFileStore().name()) @@ -597,7 +602,7 @@ private void multiPartCopy(S3Path source, long objectSize, S3Path target, CopyOp */ private long objectSize(S3Path object) { - S3Client s3Client = object.getFileSystem().getClient(); + S3Client s3Client = object.getFileStore().getClient(); final String bucket = object.getFileStore().name(); final String key = object.getKey(); final HeadObjectResponse headObjectResponse = s3Client.headObject(HeadObjectRequest.builder() @@ -659,7 +664,7 @@ public void checkAccess(Path path, AccessMode... 
modes) throws IOException { String key = s3Utils.getS3ObjectSummary(s3Path).key(); String bucket = s3Path.getFileStore().name(); S3AccessControlList accessControlList = - new S3AccessControlList(bucket, key, s3Path.getFileSystem().getClient().getObjectAcl(GetObjectAclRequest.builder().bucket(bucket).key(key).build()).grants(), s3Path.getFileStore().getOwner()); + new S3AccessControlList(bucket, key, s3Path.getFileStore().getClient().getObjectAcl(GetObjectAclRequest.builder().bucket(bucket).key(key).build()).grants(), s3Path.getFileStore().getOwner()); accessControlList.checkAccess(modes); } diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3Iterator.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3Iterator.java index 3d8f5cb8072..803b5275892 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3Iterator.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3Iterator.java @@ -1,21 +1,15 @@ package org.lerch.s3fs; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.NoSuchElementException; -import java.util.Set; - +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.lerch.s3fs.util.S3Utils; import software.amazon.awssdk.services.s3.model.CommonPrefix; import software.amazon.awssdk.services.s3.model.ListObjectsRequest; import software.amazon.awssdk.services.s3.model.ListObjectsResponse; import software.amazon.awssdk.services.s3.model.S3Object; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import org.lerch.s3fs.util.S3Utils; + +import java.nio.file.Path; +import java.util.*; /** * S3 iterator over folders at first level. @@ -50,7 +44,7 @@ public S3Iterator(S3FileStore fileStore, String key, boolean incremental) { this.fileStore = fileStore; this.fileSystem = fileStore.getFileSystem(); this.key = key; - this.current = fileSystem.getClient().listObjects(listObjectsRequest); + this.current = fileStore.getClient().listObjects(listObjectsRequest); this.incremental = incremental; loadObjects(); } @@ -69,7 +63,7 @@ public S3Path next() { .marker(current.nextMarker()) .build(); - this.current = fileSystem.getClient().listObjects(request); + this.current = fileStore.getClient().listObjects(request); loadObjects(); } if (cursor == size) diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3Path.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3Path.java index 271a1b31e72..1382c16c3ad 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/S3Path.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3Path.java @@ -4,19 +4,19 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import org.lerch.s3fs.attribute.S3BasicFileAttributes; +import software.amazon.awssdk.services.s3.S3Client; import java.io.File; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URI; -import java.net.URL; import java.net.URLDecoder; import java.nio.file.*; import java.util.Iterator; import java.util.List; -import java.util.Map; import static com.google.common.collect.Iterables.*; +import static com.google.common.collect.Iterables.concat; import static java.lang.String.format; public class S3Path implements Path { diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/S3SeekableByteChannel.java b/filesystems/s3/src/main/java/org/lerch/s3fs/S3SeekableByteChannel.java index 31318d85c25..5f0ff571521 100644 --- 
a/filesystems/s3/src/main/java/org/lerch/s3fs/S3SeekableByteChannel.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/S3SeekableByteChannel.java @@ -1,8 +1,11 @@ package org.lerch.s3fs; -import static java.lang.String.format; +import org.apache.tika.Tika; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import java.io.ByteArrayInputStream; import java.io.BufferedInputStream; import java.io.IOException; import java.io.InputStream; @@ -13,21 +16,14 @@ import java.util.HashSet; import java.util.Set; -import org.apache.tika.Tika; - -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.GetObjectResponse; -import software.amazon.awssdk.services.s3.model.GetObjectRequest; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.S3Object; +import static java.lang.String.format; public class S3SeekableByteChannel implements SeekableByteChannel, S3Channel { - private S3Path path; - private Set options; - private SeekableByteChannel seekable; - private Path tempFile; + private final S3Path path; + private final Set options; + private final SeekableByteChannel seekable; + private final Path tempFile; /** * Open or creates a file, returning a seekable byte channel @@ -52,7 +48,7 @@ else if (!exists && !this.options.contains(StandardOpenOption.CREATE_NEW) && boolean removeTempFile = true; try { if (exists) { - try (InputStream byteStream = path.getFileSystem().getClient() + try (InputStream byteStream = path.getFileStore().getClient() .getObject(GetObjectRequest.builder().bucket(path.getFileStore().getBucket().name()).key(key).build())) { Files.copy(byteStream, tempFile, StandardCopyOption.REPLACE_EXISTING); } @@ -115,7 +111,7 @@ protected void sync() throws IOException { builder.bucket(path.getFileStore().name()); builder.key(path.getKey()); - S3Client client = path.getFileSystem().getClient(); + S3Client client = path.getFileStore().getClient(); client.putObject(builder.build(), RequestBody.fromInputStream(stream, length)); } diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3BasicFileAttributes.java b/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3BasicFileAttributes.java index 40e02b17abb..895f5109fb2 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3BasicFileAttributes.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3BasicFileAttributes.java @@ -1,10 +1,10 @@ package org.lerch.s3fs.attribute; -import static java.lang.String.format; - import java.nio.file.attribute.BasicFileAttributes; import java.nio.file.attribute.FileTime; +import static java.lang.String.format; + public class S3BasicFileAttributes implements BasicFileAttributes { private final FileTime lastModifiedTime; private final long size; diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3PosixFileAttributes.java b/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3PosixFileAttributes.java index fbe5efac57c..6ffdee62f20 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3PosixFileAttributes.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/attribute/S3PosixFileAttributes.java @@ -3,8 +3,6 @@ import java.nio.file.attribute.*; import java.util.Set; -import static 
java.lang.String.format; - public class S3PosixFileAttributes extends S3BasicFileAttributes implements PosixFileAttributes { private UserPrincipal userPrincipal; diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/util/S3ClientStore.java b/filesystems/s3/src/main/java/org/lerch/s3fs/util/S3ClientStore.java new file mode 100644 index 00000000000..818a0004a3e --- /dev/null +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/util/S3ClientStore.java @@ -0,0 +1,100 @@ +package org.lerch.s3fs.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.HeadBucketResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A singleton cache of S3 clients, one per bucket, each configured for the region of its bucket + */ +public class S3ClientStore { + + private static final S3ClientStore instance = new S3ClientStore(); + + public static final S3Client DEFAULT_CLIENT = S3Client.builder().endpointOverride(URI.create("https://s3.us-east-1.amazonaws.com")).region(Region.US_EAST_1).build(); + + private final Map<String, S3Client> bucketToClientMap = Collections.synchronizedMap(new HashMap<>()); + + Logger logger = LoggerFactory.getLogger("S3ClientStore"); + + + private S3ClientStore(){} + + public static S3ClientStore getInstance() { return instance; } + + public S3Client getClientForBucketName( String bucketName ) { + logger.debug("obtaining client for bucket '{}'", bucketName); + if (bucketName == null || bucketName.trim().equals("")) { + return DEFAULT_CLIENT; + } + + return bucketToClientMap.computeIfAbsent(bucketName, this::generateClient); + } + + /** + * Generate a client for the named bucket using a default client to determine the location of the named bucket + * @param bucketName the name of the bucket to make the client for + * @return an S3 client appropriate for the region of the named bucket + */ + protected S3Client generateClient(String bucketName){ + return this.generateClient(bucketName, DEFAULT_CLIENT); + } + + /** + * Generate a client for the named bucket using a default client to determine the location of the named bucket + * @param bucketName the name of the bucket to make the client for + * @param locationClient the client used to determine the location of the named bucket; DEFAULT_CLIENT is recommended + * @return an S3 client appropriate for the region of the named bucket + */ + protected S3Client generateClient (String bucketName, S3Client locationClient) { + logger.info("generating client for bucket: '{}'", bucketName); + S3Client bucketSpecificClient; + try { + logger.info("determining bucket location with getBucketLocation"); + String bucketLocation = locationClient.getBucketLocation(builder -> builder.bucket(bucketName)).locationConstraintAsString(); + + bucketSpecificClient = this.clientForRegion(bucketLocation); + + } catch (S3Exception e) { + if(e.statusCode() == 403) { + logger.info("Cannot determine location of '{}' bucket directly.
Attempting to obtain bucket location with headBucket operation", bucketName); + try { + final HeadBucketResponse headBucketResponse = locationClient.headBucket(builder -> builder.bucket(bucketName)); + bucketSpecificClient = this.clientForRegion(headBucketResponse.sdkHttpResponse().firstMatchingHeader("x-amz-bucket-region").orElseThrow()); + } catch (S3Exception e2) { + if (e2.statusCode() == 301) { + bucketSpecificClient = this.clientForRegion(e2.awsErrorDetails().sdkHttpResponse().firstMatchingHeader("x-amz-bucket-region").orElseThrow()); + } else { + throw e2; + } + } + } else { + throw e; + } + } + + if (bucketSpecificClient == null) { + logger.warn("Unable to determine the region of bucket: '{}'. Generating a client for the current region.", bucketName); + bucketSpecificClient = S3Client.create(); + } + + return bucketSpecificClient; + } + + private S3Client clientForRegion(String regionString){ + // It may be useful to further cache clients for regions although at some point clients for buckets may need to be + // specialized beyond just region end points. + Region region = regionString.equals("") ? Region.US_EAST_1 : Region.of(regionString); + logger.info("bucket region is: '{}'", region.id()); + return S3Client.builder().region(region).build(); + } + +} diff --git a/filesystems/s3/src/main/java/org/lerch/s3fs/util/S3Utils.java b/filesystems/s3/src/main/java/org/lerch/s3fs/util/S3Utils.java index e8586ef4d97..4f4a3d24996 100644 --- a/filesystems/s3/src/main/java/org/lerch/s3fs/util/S3Utils.java +++ b/filesystems/s3/src/main/java/org/lerch/s3fs/util/S3Utils.java @@ -1,22 +1,14 @@ package org.lerch.s3fs.util; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.HeadObjectRequest; -import software.amazon.awssdk.services.s3.model.HeadObjectResponse; -import software.amazon.awssdk.services.s3.model.GetObjectAclRequest; -import software.amazon.awssdk.services.s3.model.GetObjectAclResponse; -import software.amazon.awssdk.services.s3.model.Grant; -import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; -import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; -import software.amazon.awssdk.services.s3.model.Owner; -import software.amazon.awssdk.services.s3.model.Permission; -import software.amazon.awssdk.services.s3.model.S3Object; -import software.amazon.awssdk.services.s3.model.S3Exception; import com.google.common.collect.Sets; -import org.lerch.s3fs.attribute.S3BasicFileAttributes; import org.lerch.s3fs.S3Path; +import org.lerch.s3fs.attribute.S3BasicFileAttributes; import org.lerch.s3fs.attribute.S3PosixFileAttributes; import org.lerch.s3fs.attribute.S3UserPrincipal; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.*; import java.nio.file.NoSuchFileException; import java.nio.file.attribute.FileTime; @@ -24,12 +16,12 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -import java.util.concurrent.TimeUnit; /** * Utilities to work with Amazon S3 Objects. 
*/ public class S3Utils { + Logger log = LoggerFactory.getLogger("S3Utils"); /** * Get the {@link S3Object} that represent this Path or her first child if this path not exists @@ -41,18 +33,24 @@ public class S3Utils { public S3Object getS3ObjectSummary(S3Path s3Path) throws NoSuchFileException { String key = s3Path.getKey(); String bucketName = s3Path.getFileStore().name(); - S3Client client = s3Path.getFileSystem().getClient(); + S3Client client = s3Path.getFileStore().getClient(); // try to find the element with the current key (maybe with end slash or maybe not.) try { HeadObjectResponse metadata = client.headObject(HeadObjectRequest.builder().bucket(bucketName).key(key).build()); - GetObjectAclResponse acl = client.getObjectAcl(GetObjectAclRequest.builder().bucket(bucketName).key(key).build()); + Owner objectOwner = Owner.builder().build(); + try { + GetObjectAclResponse acl = client.getObjectAcl(GetObjectAclRequest.builder().bucket(bucketName).key(key).build()); + objectOwner = acl.owner(); + } catch (S3Exception e2){ + log.warn("Unable to determine the owner of object: '{}', setting owner as empty", s3Path); + } S3Object.Builder builder = S3Object.builder(); builder .key(key) .lastModified(metadata.lastModified()) .eTag(metadata.eTag()) - .owner(acl.owner()) + .owner(objectOwner) .size(metadata.contentLength()) .storageClass(metadata.storageClassAsString()); @@ -63,9 +61,9 @@ public S3Object getS3ObjectSummary(S3Path s3Path) throws NoSuchFileException { } // if not found (404 err) with the original key. - // try to find the elment as a directory. + // try to find the element as a directory. try { - // is a virtual directory + // is a virtual directory (S3 prefix) ListObjectsV2Request.Builder request = ListObjectsV2Request.builder(); request.bucket(bucketName); String keyFolder = key; @@ -111,7 +109,7 @@ public S3PosixFileAttributes getS3PosixFileAttributes(S3Path s3Path) throws NoSu Set permissions = null; if (!attrs.isDirectory()) { - S3Client client = s3Path.getFileSystem().getClient(); + S3Client client = s3Path.getFileStore().getClient(); GetObjectAclResponse acl = client.getObjectAcl(GetObjectAclRequest.builder().bucket(bucketName).key(key).build()); Owner owner = acl.owner(); diff --git a/filesystems/s3/src/main/scala/cromwell/filesystems/s3/batch/S3BatchCommandBuilder.scala b/filesystems/s3/src/main/scala/cromwell/filesystems/s3/batch/S3BatchCommandBuilder.scala index 8058d897975..645600d1613 100644 --- a/filesystems/s3/src/main/scala/cromwell/filesystems/s3/batch/S3BatchCommandBuilder.scala +++ b/filesystems/s3/src/main/scala/cromwell/filesystems/s3/batch/S3BatchCommandBuilder.scala @@ -30,9 +30,11 @@ */ package cromwell.filesystems.s3.batch -import cromwell.core.io.{IoCommandBuilder, PartialIoCommandBuilder} +import cromwell.core.io.{IoCommandBuilder, IoContentAsStringCommand, IoIsDirectoryCommand, IoReadLinesCommand, IoWriteCommand, PartialIoCommandBuilder} +import cromwell.core.path.BetterFileMethods.OpenOptions import cromwell.core.path.Path import cromwell.filesystems.s3.S3Path +import org.slf4j.{Logger, LoggerFactory} import scala.util.Try @@ -40,6 +42,29 @@ import scala.util.Try * Generates commands for IO operations on S3 */ private case object PartialS3BatchCommandBuilder extends PartialIoCommandBuilder { + val Log: Logger = LoggerFactory.getLogger(PartialS3BatchCommandBuilder.getClass) + + + override def contentAsStringCommand: PartialFunction[(Path, Option[Int], Boolean), Try[IoContentAsStringCommand]] = { + Log.debug("call to contentAsStringCommand but 
PartialFunction not implemented, falling back to super") + super.contentAsStringCommand + } + + override def writeCommand: PartialFunction[(Path, String, OpenOptions, Boolean), Try[IoWriteCommand]] = { + Log.debug("call to writeCommand but PartialFunction not implemented, falling back to super") + super.writeCommand + } + + override def isDirectoryCommand: PartialFunction[Path, Try[IoIsDirectoryCommand]] = { + Log.debug("call to isDirectoryCommand but PartialFunction not implemented, falling back to super") + super.isDirectoryCommand + } + + override def readLinesCommand: PartialFunction[Path, Try[IoReadLinesCommand]] = { + Log.debug("call to readLinesCommand but PartialFunction not implemented, falling back to super") + super.readLinesCommand + } + override def sizeCommand: PartialFunction[Path, Try[S3BatchSizeCommand]] = { case path: S3Path => Try(S3BatchSizeCommand(path)) } diff --git a/filesystems/s3/src/test/scala/org/lerch/s3fs/util/S3ClientStoreTest.java b/filesystems/s3/src/test/scala/org/lerch/s3fs/util/S3ClientStoreTest.java new file mode 100644 index 00000000000..6f419d0c335 --- /dev/null +++ b/filesystems/s3/src/test/scala/org/lerch/s3fs/util/S3ClientStoreTest.java @@ -0,0 +1,167 @@ +package org.lerch.s3fs.util; + +import junit.framework.TestCase; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InOrder; +import org.mockito.Mock; +import static org.mockito.Mockito.*; + +import org.mockito.Spy; +import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.awscore.exception.AwsErrorDetails; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetBucketLocationResponse; +import software.amazon.awssdk.services.s3.model.HeadBucketResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +import java.util.NoSuchElementException; +import java.util.function.Consumer; + +@RunWith(MockitoJUnitRunner.class) +public class S3ClientStoreTest extends TestCase { + + S3ClientStore instance; + + @Mock + S3Client mockClient; + + @Spy + final S3ClientStore spyInstance = S3ClientStore.getInstance(); + + @Before + public void setUp() throws Exception { + super.setUp(); + instance = S3ClientStore.getInstance(); + } + + @Test + public void testGetInstanceReturnsSingleton() { + assertSame(S3ClientStore.getInstance(), instance); + } + + @Test + public void testGetClientForNullBucketName() { + assertEquals(S3ClientStore.DEFAULT_CLIENT, instance.getClientForBucketName(null)); + } + + @Test + public void testGetClientForEmptyBucketName() { + assertEquals(S3ClientStore.DEFAULT_CLIENT, instance.getClientForBucketName("")); + assertEquals(S3ClientStore.DEFAULT_CLIENT, instance.getClientForBucketName(" ")); + } + + @Test + public void testGenerateClientWithNoErrors() { + when(mockClient.getBucketLocation(any(Consumer.class))) + .thenReturn(GetBucketLocationResponse.builder().locationConstraint("us-west-2").build()); + final S3Client s3Client = instance.generateClient("test-bucket", mockClient); + assertNotNull(s3Client); + + ; + } + + @Test + public void testGenerateClientWith403Response() { + // when you get a forbidden response from getBucketLocation + when(mockClient.getBucketLocation(any(Consumer.class))).thenThrow( + S3Exception.builder().statusCode(403).build() + ); + // you should fall back to a head bucket attempt + when(mockClient.headBucket(any(Consumer.class))) + .thenReturn((HeadBucketResponse) 
HeadBucketResponse.builder() + .sdkHttpResponse(SdkHttpResponse.builder() + .putHeader("x-amz-bucket-region", "us-west-2") + .build()) + .build()); + + // which should get you a client + final S3Client s3Client = instance.generateClient("test-bucket", mockClient); + assertNotNull(s3Client); + + final InOrder inOrder = inOrder(mockClient); + inOrder.verify(mockClient).getBucketLocation(any(Consumer.class)); + inOrder.verify(mockClient).headBucket(any(Consumer.class)); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void testGenerateClientWith403Then301Responses(){ + // when you get a forbidden response from getBucketLocation + when(mockClient.getBucketLocation(any(Consumer.class))).thenThrow( + S3Exception.builder().statusCode(403).build() + ); + // and you get a 301 response on headBucket + when(mockClient.headBucket(any(Consumer.class))).thenThrow( + S3Exception.builder() + .statusCode(301) + .awsErrorDetails(AwsErrorDetails.builder() + .sdkHttpResponse(SdkHttpResponse.builder() + .putHeader("x-amz-bucket-region", "us-west-2") + .build()) + .build()) + .build() + ); + + // then you should be able to get a client as long as the error response header contains the region + final S3Client s3Client = instance.generateClient("test-bucket", mockClient); + assertNotNull(s3Client); + + final InOrder inOrder = inOrder(mockClient); + inOrder.verify(mockClient).getBucketLocation(any(Consumer.class)); + inOrder.verify(mockClient).headBucket(any(Consumer.class)); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void testGenerateClientWith403Then301ResponsesNoHeader(){ + // when you get a forbidden response from getBucketLocation + when(mockClient.getBucketLocation(any(Consumer.class))).thenThrow( + S3Exception.builder().statusCode(403).build() + ); + // and you get a 301 response on headBucket but no header for region + when(mockClient.headBucket(any(Consumer.class))).thenThrow( + S3Exception.builder() + .statusCode(301) + .awsErrorDetails(AwsErrorDetails.builder() + .sdkHttpResponse(SdkHttpResponse.builder() + .build()) + .build()) + .build() + ); + + // then you should get a NoSuchElement exception when you try to get the header + try { + instance.generateClient("test-bucket", mockClient); + } catch (Exception e) { + assertEquals(NoSuchElementException.class, e.getClass()); + } + + final InOrder inOrder = inOrder(mockClient); + inOrder.verify(mockClient).getBucketLocation(any(Consumer.class)); + inOrder.verify(mockClient).headBucket(any(Consumer.class)); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void testCaching() { + S3Client client = S3Client.create(); + doReturn(client).when(spyInstance).generateClient("test-bucket"); + + final S3Client client1 = spyInstance.getClientForBucketName("test-bucket"); + verify(spyInstance).generateClient("test-bucket"); + assertSame(client1, client); + + S3Client differentClient = S3Client.create(); + assertNotSame(client, differentClient); + + lenient().doReturn(differentClient).when(spyInstance).generateClient("test-bucket"); + final S3Client client2 = spyInstance.getClientForBucketName("test-bucket"); + // same instance because second is cached. 
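+        // The lenient() stub above would now hand out differentClient, but the client cached by the
+        // first lookup short-circuits generateClient, so that stub is never exercised and the
+        // original client is returned again.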
+ assertSame(client1, client2); + assertSame(client2, client); + assertNotSame(client2, differentClient); + } +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 4504a64ffcc..9b7103a48df 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -10,7 +10,7 @@ object Dependencies { private val aliyunOssV = "3.13.1" private val ammoniteOpsV = "2.4.0" private val apacheHttpClientV = "4.5.13" - private val awsSdkV = "2.17.50" + private val awsSdkV = "2.17.66" // We would like to use the BOM to manage Azure SDK versions, but SBT doesn't support it. // https://github.com/Azure/azure-sdk-for-java/tree/main/sdk/boms/azure-sdk-bom // https://github.com/sbt/sbt/issues/4531 @@ -320,6 +320,9 @@ object Dependencies { "cloudwatchlogs", "s3", "sts", + "ecs", + "ecr", + "ecrpublic", ).map(artifactName => "software.amazon.awssdk" % artifactName % awsSdkV) private val googleCloudDependencies = List( @@ -478,7 +481,7 @@ object Dependencies { val databaseMigrationDependencies: List[ModuleID] = liquibaseDependencies ++ dbmsDependencies - val dockerHashingDependencies: List[ModuleID] = http4sDependencies ++ circeDependencies ++ aliyunCrDependencies + val dockerHashingDependencies: List[ModuleID] = http4sDependencies ++ circeDependencies ++ aliyunCrDependencies ++ awsCloudDependencies val cromwellApiClientDependencies: List[ModuleID] = List( "org.typelevel" %% "cats-effect" % catsEffectV, diff --git a/services/src/main/scala/cromwell/services/metadata/impl/MetadataServiceActor.scala b/services/src/main/scala/cromwell/services/metadata/impl/MetadataServiceActor.scala index 9c17f60895c..f7fc145937d 100644 --- a/services/src/main/scala/cromwell/services/metadata/impl/MetadataServiceActor.scala +++ b/services/src/main/scala/cromwell/services/metadata/impl/MetadataServiceActor.scala @@ -62,7 +62,7 @@ case class MetadataServiceActor(serviceConfig: Config, globalConfig: Config, ser private val metadataReadTimeout: Duration = serviceConfig.getOrElse[Duration]("metadata-read-query-timeout", Duration.Inf) private val metadataReadRowNumberSafetyThreshold: Int = - serviceConfig.getOrElse[Int]("metadata-read-row-number-safety-threshold", 1000000) + serviceConfig.getOrElse[Int]("metadata-read-row-number-safety-threshold", 3000000) private val metadataTableMetricsInterval: Option[FiniteDuration] = serviceConfig.getAs[FiniteDuration]("metadata-table-metrics-interval") diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActor.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActor.scala index 4fdba3b568c..8c5661b4a28 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActor.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAsyncBackendJobExecutionActor.scala @@ -202,10 +202,18 @@ class AwsBatchAsyncBackendJobExecutionActor(override val standardParams: Standar private def inputsFromWomFiles(namePrefix: String, remotePathArray: Seq[WomFile], localPathArray: Seq[WomFile], - jobDescriptor: BackendJobDescriptor): Iterable[AwsBatchInput] = { + jobDescriptor: BackendJobDescriptor, + flag: Boolean): Iterable[AwsBatchInput] = { + (remotePathArray zip localPathArray zipWithIndex) flatMap { case ((remotePath, localPath), index) => - Seq(AwsBatchFileInput(s"$namePrefix-$index", remotePath.valueString, DefaultPathBuilder.get(localPath.valueString), workingDisk)) + var localPathString = 
localPath.valueString + if (localPathString.startsWith("s3://")){ + localPathString = localPathString.replace("s3://", "") + }else if (localPathString.startsWith("s3:/")) { + localPathString = localPathString.replace("s3:/", "") + } + Seq(AwsBatchFileInput(s"$namePrefix-$index", remotePath.valueString, DefaultPathBuilder.get(localPathString), workingDisk)) } } @@ -237,7 +245,7 @@ class AwsBatchAsyncBackendJobExecutionActor(override val standardParams: Standar val writeFunctionFiles = instantiatedCommand.createdFiles map { f => f.file.value.md5SumShort -> List(f) } toMap val writeFunctionInputs = writeFunctionFiles flatMap { - case (name, files) => inputsFromWomFiles(name, files.map(_.file), files.map(localizationPath), jobDescriptor) + case (name, files) => inputsFromWomFiles(name, files.map(_.file), files.map(localizationPath), jobDescriptor, false) } // Collect all WomFiles from inputs to the call. @@ -257,7 +265,7 @@ class AwsBatchAsyncBackendJobExecutionActor(override val standardParams: Standar } val callInputInputs = callInputFiles flatMap { - case (name, files) => inputsFromWomFiles(name, files, files.map(relativeLocalizationPath), jobDescriptor) + case (name, files) => inputsFromWomFiles(name, files, files.map(relativeLocalizationPath), jobDescriptor, true) } val scriptInput: AwsBatchInput = AwsBatchFileInput( diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala index 26f69c4e79a..dffcb983235 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala @@ -71,7 +71,14 @@ object AwsBatchAttributes { "filesystems.local.auth", "filesystems.s3.auth", "filesystems.s3.caching.duplication-strategy", - "filesystems.local.caching.duplication-strategy" + "filesystems.local.caching.duplication-strategy", + "auth", + "numCreateDefinitionAttempts", + "filesystems.s3.duplication-strategy", + "numSubmitAttempts", + "default-runtime-attributes.scriptBucketName", + "awsBatchRetryAttempts", + "ulimits" ) private val deprecatedAwsBatchKeys: Map[String, String] = Map( diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala index a22cf25f307..de1b4640695 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala @@ -30,8 +30,6 @@ */ package cromwell.backend.impl.aws -import java.security.MessageDigest - import cats.data.ReaderT._ import cats.data.{Kleisli, ReaderT} import cats.effect.{Async, Timer} @@ -53,10 +51,11 @@ import software.amazon.awssdk.services.s3.S3Client import software.amazon.awssdk.services.s3.model.{GetObjectRequest, HeadObjectRequest, NoSuchKeyException, PutObjectRequest} import wdl4s.parser.MemoryUnit +import java.security.MessageDigest import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.higherKinds -import scala.util.{Random, Try} +import scala.util.Try /** * The actual job for submission in AWS batch. `AwsBatchJob` is the primary interface to AWS Batch. 
It creates the @@ -83,18 +82,12 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL optAwsAuthMode: Option[AwsAuthMode] = None ) { - // values for container environment - val AWS_MAX_ATTEMPTS: String = "AWS_MAX_ATTEMPTS" - val AWS_MAX_ATTEMPTS_DEFAULT_VALUE: String = "14" - val AWS_RETRY_MODE: String = "AWS_RETRY_MODE" - val AWS_RETRY_MODE_DEFAULT_VALUE: String = "adaptive" val Log: Logger = LoggerFactory.getLogger(AwsBatchJob.getClass) //this will be the "folder" that scripts will live in (underneath the script bucket) val scriptKeyPrefix = "scripts/" - // TODO: Auth, endpoint lazy val batchClient: BatchClient = { val builder = BatchClient.builder() configureClient(builder, optAwsAuthMode, configRegion) @@ -117,26 +110,25 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL */ lazy val reconfiguredScript: String = { //this is the location of the aws cli mounted into the container by the ec2 launch template - val s3Cmd = "/usr/local/aws-cli/v2/current/bin/aws s3" + val awsCmd = "/usr/local/aws-cli/v2/current/bin/aws " //internal to the container, therefore not mounted val workDir = "/tmp/scratch" //working in a mount will cause collisions in long running workers val replaced = commandScript.replaceAllLiterally(AwsBatchWorkingDisk.MountPoint.pathAsString, workDir) val insertionPoint = replaced.indexOf("\n", replaced.indexOf("#!")) +1 //just after the new line after the shebang! - /* generate a series of s3 copy statements to copy any s3 files into the container. We randomize the order - so that large scatters don't all attempt to copy the same thing at the same time. */ - val inputCopyCommand = Random.shuffle(inputs.map { + /* generate a series of s3 copy statements to copy any s3 files into the container. */ + val inputCopyCommand = inputs.map { case input: AwsBatchFileInput if input.s3key.startsWith("s3://") && input.s3key.endsWith(".tmp") => //we are localizing a tmp file which may contain workdirectory paths that need to be reconfigured s""" - |$s3Cmd cp --no-progress ${input.s3key} $workDir/${input.local} + |_s3_localize_with_retry ${input.s3key} $workDir/${input.local} |sed -i 's#${AwsBatchWorkingDisk.MountPoint.pathAsString}#$workDir#g' $workDir/${input.local} |""".stripMargin case input: AwsBatchFileInput if input.s3key.startsWith("s3://") => - s"$s3Cmd cp --no-progress ${input.s3key} ${input.mount.mountPoint.pathAsString}/${input.local}" + s"_s3_localize_with_retry ${input.s3key} ${input.mount.mountPoint.pathAsString}/${input.local}" .replaceAllLiterally(AwsBatchWorkingDisk.MountPoint.pathAsString, workDir) case input: AwsBatchFileInput => @@ -147,11 +139,41 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL s"test -e $filePath || echo 'input file: $filePath does not exist' && exit 1" case _ => "" - }.toList).mkString("\n") + }.toList.mkString("\n") // this goes at the start of the script after the #! 
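+    // The preamble exports the AWS metadata-service timeout/attempt settings and defines
+    // _s3_localize_with_retry, which verifies each download against the object's ContentLength
+    // (via `aws s3api head-object`) and retries up to 5 times, sleeping 7 * attempt seconds
+    // between attempts, before exiting with a non-zero code.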
val preamble = s""" + |export AWS_METADATA_SERVICE_TIMEOUT=10 + |export AWS_METADATA_SERVICE_NUM_ATTEMPTS=10 + | + |function _s3_localize_with_retry() { + | local s3_path=$$1 + | # destination must be the path to a file and not just the directory you want the file in + | local destination=$$2 + | + | for i in {1..5}; + | do + | if [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then + | bucket="$${BASH_REMATCH[1]}" + | key="$${BASH_REMATCH[2]}" + | content_length=$$($awsCmd s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') + | else + | echo "$$s3_path is not an S3 path with a bucket and key. aborting" + | exit 1 + | fi + | $awsCmd s3 cp --no-progress "$$s3_path" "$$destination" && + | [[ $$(LC_ALL=C ls -dn -- "$$destination" | awk '{print $$5; exit}') -eq "$$content_length" ]] && break || + | echo "attempt $$i to copy $$s3_path failed"; + | + | if [ "$$i" -eq 5 ]; then + | echo "failed to copy $$s3_path after $$i attempts. aborting" + | exit 2 + | fi + | sleep $$((7 * "$$i")) + | done + |} + | |{ |set -e |echo '*** LOCALIZING INPUTS ***' @@ -181,24 +203,24 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL */ s""" |touch ${output.name} - |$s3Cmd cp --no-progress ${output.name} ${output.s3key} - |if [ -e $globDirectory ]; then $s3Cmd cp --no-progress $globDirectory $s3GlobOutDirectory --recursive --exclude "cromwell_glob_control_file"; fi + |$awsCmd s3 cp --no-progress ${output.name} ${output.s3key} + |if [ -e $globDirectory ]; then $awsCmd s3 cp --no-progress $globDirectory $s3GlobOutDirectory --recursive --exclude "cromwell_glob_control_file"; fi |""".stripMargin case output: AwsBatchFileOutput if output.s3key.startsWith("s3://") && output.mount.mountPoint.pathAsString == AwsBatchWorkingDisk.MountPoint.pathAsString => //output is on working disk mount s""" - |$s3Cmd cp --no-progress $workDir/${output.local.pathAsString} ${output.s3key} + |$awsCmd s3 cp --no-progress $workDir/${output.local.pathAsString} ${output.s3key} |""".stripMargin case output: AwsBatchFileOutput => //output on a different mount - s"$s3Cmd cp --no-progress ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString} ${output.s3key}" + s"$awsCmd s3 cp --no-progress ${output.mount.mountPoint.pathAsString}/${output.local.pathAsString} ${output.s3key}" case _ => "" }.mkString("\n") + "\n" + s""" - |if [ -f $workDir/${jobPaths.returnCodeFilename} ]; then $s3Cmd cp --no-progress $workDir/${jobPaths.returnCodeFilename} ${jobPaths.callRoot.pathAsString}/${jobPaths.returnCodeFilename} ; fi\n - |if [ -f $stdErr ]; then $s3Cmd cp --no-progress $stdErr ${jobPaths.standardPaths.error.pathAsString}; fi - |if [ -f $stdOut ]; then $s3Cmd cp --no-progress $stdOut ${jobPaths.standardPaths.output.pathAsString}; fi + |if [ -f $workDir/${jobPaths.returnCodeFilename} ]; then $awsCmd s3 cp --no-progress $workDir/${jobPaths.returnCodeFilename} ${jobPaths.callRoot.pathAsString}/${jobPaths.returnCodeFilename} ; fi\n + |if [ -f $stdErr ]; then $awsCmd s3 cp --no-progress $stdErr ${jobPaths.standardPaths.error.pathAsString}; fi + |if [ -f $stdOut ]; then $awsCmd s3 cp --no-progress $stdOut ${jobPaths.standardPaths.output.pathAsString}; fi |""".stripMargin @@ -210,6 +232,10 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL |echo '*** DELOCALIZING OUTPUTS ***' |$outputCopyCommand |echo '*** COMPLETED DELOCALIZATION ***' + |echo '*** EXITING WITH RETURN CODE ***' + |rc=$$(head -n 1 $workDir/${jobPaths.returnCodeFilename}) + |echo $$rc + |exit $$rc |} 
|""".stripMargin } @@ -219,8 +245,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL } private def generateEnvironmentKVPairs(scriptBucketName: String, scriptKeyPrefix: String, scriptKey: String): List[KeyValuePair] = { - List(buildKVPair(AWS_MAX_ATTEMPTS, AWS_MAX_ATTEMPTS_DEFAULT_VALUE), - buildKVPair(AWS_RETRY_MODE, AWS_RETRY_MODE_DEFAULT_VALUE), + List( buildKVPair("BATCH_FILE_TYPE", "script"), buildKVPair("BATCH_FILE_S3_URL",batch_file_s3_url(scriptBucketName,scriptKeyPrefix,scriptKey))) } @@ -262,18 +287,11 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL .containerOverrides( ContainerOverrides.builder .environment( - generateEnvironmentKVPairs(runtimeAttributes.scriptS3BucketName, scriptKeyPrefix, scriptKey): _* ) .resourceRequirements( - ResourceRequirement.builder() - .`type`(ResourceType.VCPU) - .value(runtimeAttributes.cpu.value.toString) - .build(), - ResourceRequirement.builder() - .`type`(ResourceType.MEMORY) - .value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString) - .build(), + ResourceRequirement.builder().`type`(ResourceType.VCPU).value(runtimeAttributes.cpu.##.toString).build(), + ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build() ) .build() ) @@ -395,16 +413,19 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL // See: // // http://aws-java-sdk-javadoc.s3-website-us-west-2.amazonaws.com/latest/software/amazon/awssdk/services/batch/model/RegisterJobDefinitionRequest.Builder.html - val definitionRequest = RegisterJobDefinitionRequest.builder + var definitionRequest = RegisterJobDefinitionRequest.builder .containerProperties(jobDefinition.containerProperties) .jobDefinitionName(jobDefinitionName) // See https://stackoverflow.com/questions/24349517/scala-method-named-type .`type`(JobDefinitionType.CONTAINER) - .build + + if (jobDefinitionContext.runtimeAttributes.awsBatchRetryAttempts != 0){ + definitionRequest = definitionRequest.retryStrategy(jobDefinition.retryStrategy) + } Log.debug(s"Submitting definition request: $definitionRequest") - val response: RegisterJobDefinitionResponse = batchClient.registerJobDefinition(definitionRequest) + val response: RegisterJobDefinitionResponse = batchClient.registerJobDefinition(definitionRequest.build) Log.info(s"Definition created: $response") response.jobDefinitionArn() } @@ -447,7 +468,6 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL } yield runStatus def detail(jobId: String): JobDetail = { - //TODO: This client call should be wrapped in a cats Effect val describeJobsResponse = batchClient.describeJobs(DescribeJobsRequest.builder.jobs(jobId).build) val jobDetail = describeJobsResponse.jobs.asScala.headOption. 
diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala index ae356b89772..b8ac5cb229a 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala @@ -34,7 +34,7 @@ package cromwell.backend.impl.aws import scala.collection.mutable.ListBuffer import cromwell.backend.BackendJobDescriptor import cromwell.backend.io.JobPaths -import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, Volume} +import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, RetryStrategy, Ulimit, Volume} import cromwell.backend.impl.aws.io.AwsBatchVolume import scala.collection.JavaConverters._ @@ -60,12 +60,14 @@ import wdl4s.parser.MemoryUnit */ sealed trait AwsBatchJobDefinition { def containerProperties: ContainerProperties + def retryStrategy: RetryStrategy def name: String override def toString: String = { new ToStringBuilder(this, ToStringStyle.JSON_STYLE) .append("name", name) .append("containerProperties", containerProperties) + .append("retryStrategy", retryStrategy) .build } } @@ -76,23 +78,13 @@ trait AwsBatchJobDefinitionBuilder { /** Gets a builder, seeded with appropriate portions of the container properties * - * @param dockerImage docker image with which to run - * @return ContainerProperties builder ready for modification + * @param context AwsBatchJobDefinitionContext with all the runtime attributes + * @return ContainerProperties builder ready for modification and name * */ - def builder(dockerImage: String): ContainerProperties.Builder = - ContainerProperties.builder().image(dockerImage) - - - def buildResources(builder: ContainerProperties.Builder, - context: AwsBatchJobDefinitionContext): (ContainerProperties.Builder, String) = { - // The initial buffer should only contain one item - the hostpath of the - // local disk mount point, which will be needed by the docker container - // that copies data around - - val environment = List.empty[KeyValuePair] - - + def containerPropertiesBuilder(context: AwsBatchJobDefinitionContext): (ContainerProperties.Builder, String) = { + + def buildVolumes(disks: Seq[AwsBatchVolume]): List[Volume] = { //all the configured disks plus the fetch and run volume and the aws-cli volume @@ -109,6 +101,7 @@ trait AwsBatchJobDefinitionBuilder { ) } + def buildMountPoints(disks: Seq[AwsBatchVolume]): List[MountPoint] = { //all the configured disks plus the fetch and run mount point and the AWS cli mount point @@ -128,53 +121,65 @@ trait AwsBatchJobDefinitionBuilder { ) } - def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair]): String = { - val str = s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}" - val sha1 = MessageDigest.getInstance("SHA-1") - .digest( str.getBytes("UTF-8") ) - .map("%02x".format(_)).mkString + def buildUlimits(ulimits: Seq[Map[String, String]]): List[Ulimit] = { - val prefix = s"cromwell_$imageName".slice(0,88) // will be joined to a 40 character SHA1 for total length of 128 - - sanitize(prefix + sha1) + ulimits.filter(_.nonEmpty).map(u 
=> + Ulimit.builder() + .name(u("name")) + .softLimit(u("softLimit").toInt) + .hardLimit(u("hardLimit").toInt) + .build() + ).toList } + def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair], ulimits: List[Ulimit]): String = { + s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}:${ulimits.map(_.toString).mkString(",")}" + } + + + val environment = List.empty[KeyValuePair] val cmdName = context.runtimeAttributes.fileSystem match { - case AWSBatchStorageSystems.s3 => "/var/scratch/fetch_and_run.sh" - case _ => context.commandText + case AWSBatchStorageSystems.s3 => "/var/scratch/fetch_and_run.sh" + case _ => context.commandText } val packedCommand = packCommand("/bin/bash", "-c", cmdName) val volumes = buildVolumes( context.runtimeAttributes.disks ) val mountPoints = buildMountPoints( context.runtimeAttributes.disks) - val jobDefinitionName = buildName( + val ulimits = buildUlimits( context.runtimeAttributes.ulimits) + val containerPropsName = buildName( context.runtimeAttributes.dockerImage, packedCommand.mkString(","), volumes, mountPoints, - environment + environment, + ulimits ) - (builder - .command(packedCommand.asJava) + (ContainerProperties.builder() + .image(context.runtimeAttributes.dockerImage) + .command(packedCommand.asJava) .resourceRequirements( - ResourceRequirement.builder() - .`type`(ResourceType.MEMORY) - .value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString) - .build(), - ResourceRequirement.builder() - .`type`(ResourceType.VCPU) - .value(context.runtimeAttributes.cpu.value.toString) - .build(), + ResourceRequirement.builder().`type`(ResourceType.VCPU).value(context.runtimeAttributes.cpu.##.toString).build(), + ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build() ) - .volumes( volumes.asJava) - .mountPoints( mountPoints.asJava) - .environment(environment.asJava), + .volumes(volumes.asJava) + .mountPoints(mountPoints.asJava) + .environment(environment.asJava) + .ulimits(ulimits.asJava), + containerPropsName) + } - jobDefinitionName) + def retryStrategyBuilder(context: AwsBatchJobDefinitionContext): (RetryStrategy.Builder, String) = { + // We can add here the 'evaluateOnExit' statement + + (RetryStrategy.builder() + .attempts(context.runtimeAttributes.awsBatchRetryAttempts), + context.runtimeAttributes.awsBatchRetryAttempts.toString) } + private def packCommand(shell: String, options: String, mainCommand: String): Seq[String] = { val rc = new ListBuffer[String]() val lim = 1024 @@ -195,15 +200,29 @@ trait AwsBatchJobDefinitionBuilder { object StandardAwsBatchJobDefinitionBuilder extends AwsBatchJobDefinitionBuilder { def build(context: AwsBatchJobDefinitionContext): AwsBatchJobDefinition = { - //instantiate a builder with the name of the docker image - val builderInst = builder(context.runtimeAttributes.dockerImage) - val (b, name) = buildResources(builderInst, context) + + val (containerPropsInst, containerPropsName) = containerPropertiesBuilder(context) + val (retryStrategyInst, retryStrategyName) = retryStrategyBuilder(context) + + val name = buildName(context.runtimeAttributes.dockerImage, containerPropsName, retryStrategyName) - new StandardAwsBatchJobDefinitionBuilder(b.build, name) + new StandardAwsBatchJobDefinitionBuilder(containerPropsInst.build, retryStrategyInst.build, 
name) } + + def buildName(imageName: String, containerPropsName: String, retryStrategyName: String): String = { + val str = s"$imageName:$containerPropsName:$retryStrategyName" + + val sha1 = MessageDigest.getInstance("SHA-1") + .digest( str.getBytes("UTF-8") ) + .map("%02x".format(_)).mkString + + val prefix = s"cromwell_${imageName}_".slice(0,88) // will be joined to a 40 character SHA1 for total length of 128 + + sanitize(prefix + sha1) + } } -case class StandardAwsBatchJobDefinitionBuilder private(containerProperties: ContainerProperties, name: String) extends AwsBatchJobDefinition +case class StandardAwsBatchJobDefinitionBuilder private(containerProperties: ContainerProperties, retryStrategy: RetryStrategy, name: String) extends AwsBatchJobDefinition object AwsBatchJobDefinitionContext diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala index c6fc2a5f51f..8296eefd42a 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala @@ -60,6 +60,8 @@ import scala.util.matching.Regex * @param noAddress is there no address * @param scriptS3BucketName the s3 bucket where the execution command or script will be written and, from there, fetched into the container and executed * @param fileSystem the filesystem type, default is "s3" + * @param awsBatchRetryAttempts number of attempts that AWS Batch will retry the task if it fails + * @param ulimits ulimit values to be passed to the container */ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive, zones: Vector[String], @@ -71,7 +73,9 @@ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive, continueOnReturnCode: ContinueOnReturnCode, noAddress: Boolean, scriptS3BucketName: String, - fileSystem:String= "s3") + awsBatchRetryAttempts: Int, + ulimits: Vector[Map[String, String]], + fileSystem: String= "s3") object AwsBatchRuntimeAttributes { @@ -79,6 +83,8 @@ object AwsBatchRuntimeAttributes { val scriptS3BucketKey = "scriptBucketName" + val awsBatchRetryAttemptsKey = "awsBatchRetryAttempts" + val ZonesKey = "zones" private val ZonesDefaultValue = WomString("us-east-1a") @@ -92,6 +98,9 @@ object AwsBatchRuntimeAttributes { private val MemoryDefaultValue = "2 GB" + val UlimitsKey = "ulimits" + private val UlimitsDefaultValue = WomArray(WomArrayType(WomMapType(WomStringType,WomStringType)), Vector(WomMap(Map.empty[WomValue, WomValue]))) + private def cpuValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instance .withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin) @@ -134,6 +143,14 @@ object AwsBatchRuntimeAttributes { QueueArnValidation.withDefault(QueueArnValidation.configDefaultWomValue(runtimeConfig) getOrElse (throw new RuntimeException("queueArn is required"))) + private def awsBatchRetryAttemptsValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int] = { + AwsBatchRetryAttemptsValidation(awsBatchRetryAttemptsKey).withDefault(AwsBatchRetryAttemptsValidation(awsBatchRetryAttemptsKey) + .configDefaultWomValue(runtimeConfig).getOrElse(WomInteger(0))) + } + + private def ulimitsValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Vector[Map[String, String]]] = + 
UlimitsValidation.withDefault(UlimitsValidation.configDefaultWomValue(runtimeConfig) getOrElse UlimitsDefaultValue) + def runtimeAttributesBuilder(configuration: AwsBatchConfiguration): StandardValidatedRuntimeAttributesBuilder = { val runtimeConfig = configuration.runtimeConfig def validationsS3backend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation( @@ -146,7 +163,9 @@ object AwsBatchRuntimeAttributes { noAddressValidation(runtimeConfig), dockerValidation, queueArnValidation(runtimeConfig), - scriptS3BucketNameValidation(runtimeConfig) + scriptS3BucketNameValidation(runtimeConfig), + awsBatchRetryAttemptsValidation(runtimeConfig), + ulimitsValidation(runtimeConfig), ) def validationsLocalBackend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation( cpuValidation(runtimeConfig), @@ -181,6 +200,8 @@ object AwsBatchRuntimeAttributes { case AWSBatchStorageSystems.s3 => RuntimeAttributesValidation.extract(scriptS3BucketNameValidation(runtimeAttrsConfig) , validatedRuntimeAttributes) case _ => "" } + val awsBatchRetryAttempts: Int = RuntimeAttributesValidation.extract(awsBatchRetryAttemptsValidation(runtimeAttrsConfig), validatedRuntimeAttributes) + val ulimits: Vector[Map[String, String]] = RuntimeAttributesValidation.extract(ulimitsValidation(runtimeAttrsConfig), validatedRuntimeAttributes) new AwsBatchRuntimeAttributes( @@ -194,6 +215,8 @@ object AwsBatchRuntimeAttributes { continueOnReturnCode, noAddress, scriptS3BucketName, + awsBatchRetryAttempts, + ulimits, fileSystem ) } @@ -372,3 +395,94 @@ object DisksValidation extends RuntimeAttributesValidation[Seq[AwsBatchVolume]] override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be a comma separated String or Array[String]" } + +object AwsBatchRetryAttemptsValidation { + def apply(key: String): AwsBatchRetryAttemptsValidation = new AwsBatchRetryAttemptsValidation(key) +} + +class AwsBatchRetryAttemptsValidation(key: String) extends IntRuntimeAttributesValidation(key) { + override protected def validateValue: PartialFunction[WomValue, ErrorOr[Int]] = { + case womValue if WomIntegerType.coerceRawValue(womValue).isSuccess => + WomIntegerType.coerceRawValue(womValue).get match { + case WomInteger(value) => + if (value.toInt < 0) + s"Expecting $key runtime attribute value greater than or equal to 0".invalidNel + else if (value.toInt > 10) + s"Expecting $key runtime attribute value lower than or equal to 10".invalidNel + else + value.toInt.validNel + } + } + + override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be an Integer" +} + + +object UlimitsValidation + extends RuntimeAttributesValidation[Vector[Map[String, String]]] { + override def key: String = AwsBatchRuntimeAttributes.UlimitsKey + + override def coercion: Traversable[WomType] = + Set(WomStringType, WomArrayType(WomMapType(WomStringType, WomStringType))) + + var accepted_keys = Set("name", "softLimit", "hardLimit") + + override protected def validateValue + : PartialFunction[WomValue, ErrorOr[Vector[Map[String, String]]]] = { + case WomArray(womType, value) + if womType.memberType == WomMapType(WomStringType, WomStringType) => + check_maps(value.toVector) + case WomMap(_, _) => "!!! ERROR1".invalidNel + + } + + private def check_maps( + maps: Vector[WomValue] + ): ErrorOr[Vector[Map[String, String]]] = { + val entryNels: Vector[ErrorOr[Map[String, String]]] = maps.map { + case WomMap(_, value) => check_keys(value) + case _ => "!!! 
ERROR2".invalidNel + } + val sequenced: ErrorOr[Vector[Map[String, String]]] = sequenceNels( + entryNels + ) + sequenced + } + + private def check_keys( + dict: Map[WomValue, WomValue] + ): ErrorOr[Map[String, String]] = { + val map_keys = dict.keySet.map(_.valueString).toSet + val unrecognizedKeys = + accepted_keys.diff(map_keys) union map_keys.diff(accepted_keys) + + if (!dict.nonEmpty){ + Map.empty[String, String].validNel + }else if (unrecognizedKeys.nonEmpty) { + s"Invalid keys in $key runtime attribute. Refer to 'ulimits' section on https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#containerProperties".invalidNel + } else { + dict + .collect { case (WomString(k), WomString(v)) => + (k, v) + // case _ => "!!! ERROR3".invalidNel + } + .toMap + .validNel + } + } + + private def sequenceNels( + nels: Vector[ErrorOr[Map[String, String]]] + ): ErrorOr[Vector[Map[String, String]]] = { + val emptyNel: ErrorOr[Vector[Map[String, String]]] = + Vector.empty[Map[String, String]].validNel + val seqNel: ErrorOr[Vector[Map[String, String]]] = + nels.foldLeft(emptyNel) { (acc, v) => + (acc, v) mapN { (a, v) => a :+ v } + } + seqNel + } + + override protected def missingValueMessage: String = + s"Expecting $key runtime attribute to be an Array[Map[String, String]]" +} \ No newline at end of file diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala index 5037fc21051..12e933ab959 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala @@ -34,7 +34,9 @@ package cromwell.backend.impl.aws import common.collections.EnhancedCollections._ import cromwell.backend.{BackendJobDescriptorKey, BackendWorkflowDescriptor} import cromwell.backend.BackendSpec._ +import cromwell.backend.impl.aws.io.AwsBatchWorkingDisk import cromwell.backend.validation.ContinueOnReturnCodeFlag +import cromwell.core.path.DefaultPathBuilder import cromwell.core.TestKitSuite import cromwell.util.SampleWdl import eu.timepit.refined.api.Refined @@ -45,7 +47,7 @@ import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers import org.specs2.mock.Mockito import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider -import software.amazon.awssdk.services.batch.model.KeyValuePair +import software.amazon.awssdk.services.batch.model.{ContainerDetail, JobDetail, KeyValuePair} import spray.json.{JsObject, JsString} import wdl4s.parser.MemoryUnit import wom.format.MemorySize @@ -56,7 +58,7 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi System.setProperty("aws.region", "us-east-1") - val script = """ + val script: String = """ |tmpDir=mkdir -p "/cromwell-aws/cromwell-execution/wf_hello/2422ea26-2578-48b0-86e9-50cbdda7d70a/call-hello/tmp.39397e83" && echo "/cromwell-aws/cromwell-execution/wf_hello/2422ea26-2578-48b0-86e9-50cbdda7d70a/call-hello/tmp.39397e83" |chmod 777 "$tmpDir" |export _JAVA_OPTIONS=-Djava.io.tmpdir="$tmpDir" @@ -84,7 +86,7 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi | | |) - |mv /cromwell_root/hello-rc.txt.tmp /cromwell_root/hello-rc.txt""" + |mv /cromwell_root/hello-rc.txt.tmp /cromwell_root/hello-rc.txt""".stripMargin val workFlowDescriptor: BackendWorkflowDescriptor = buildWdlWorkflowDescriptor( 
SampleWdl.HelloWorld.workflowSource(), @@ -100,6 +102,8 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi val call: CommandCallNode = workFlowDescriptor.callable.taskCallNodes.head val jobKey: BackendJobDescriptorKey = BackendJobDescriptorKey(call, None, 1) val jobPaths: AwsBatchJobPaths = AwsBatchJobPaths(workflowPaths, jobKey) + val s3Inputs: Set[AwsBatchInput] = Set(AwsBatchFileInput("foo", "s3://bucket/foo", DefaultPathBuilder.get("foo"), AwsBatchWorkingDisk())) + val s3Outputs: Set[AwsBatchFileOutput] = Set(AwsBatchFileOutput("baa", "s3://bucket/somewhere/baa", DefaultPathBuilder.get("baa"), AwsBatchWorkingDisk())) val cpu: Int Refined Positive = 2 val runtimeAttributes: AwsBatchRuntimeAttributes = new AwsBatchRuntimeAttributes( @@ -113,8 +117,13 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi continueOnReturnCode = ContinueOnReturnCodeFlag(false), noAddress = false, scriptS3BucketName = "script-bucket", + awsBatchRetryAttempts = 1, + ulimits = Vector(Map.empty[String, String]), fileSystem = "s3") + val containerDetail: ContainerDetail = ContainerDetail.builder().exitCode(0).build() + val jobDetail: JobDetail = JobDetail.builder().container(containerDetail).build + private def generateBasicJob: AwsBatchJob = { val job = AwsBatchJob(null, runtimeAttributes, "commandLine", script, "/cromwell_root/hello-rc.txt", "/cromwell_root/hello-stdout.log", "/cromwell_root/hello-stderr.log", @@ -129,29 +138,22 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi jobPaths, Seq.empty[AwsBatchParameter], None) job } + private def generateJobWithS3InOut: AwsBatchJob = { + val job = AwsBatchJob(null, runtimeAttributes, "commandLine", script, + "/cromwell_root/hello-rc.txt", "/cromwell_root/hello-stdout.log", "/cromwell_root/hello-stderr.log", + s3Inputs, s3Outputs, + jobPaths, Seq.empty[AwsBatchParameter], None) + job + } // TESTS BEGIN HERE behavior of "AwsBatchJob" - - it should "have correctly named AWS constants" in { - - val job: AwsBatchJob = generateBasicJob - - job.AWS_RETRY_MODE should be ("AWS_RETRY_MODE") - job.AWS_RETRY_MODE_DEFAULT_VALUE should be ("adaptive") - job.AWS_MAX_ATTEMPTS should be ("AWS_MAX_ATTEMPTS") - job.AWS_MAX_ATTEMPTS_DEFAULT_VALUE should be ("14") - } - it should "generate appropriate KV pairs for the container environment for S3" in { val job = generateBasicJob val generateEnvironmentKVPairs = PrivateMethod[List[KeyValuePair]]('generateEnvironmentKVPairs) // testing a private method see https://www.scalatest.org/user_guide/using_PrivateMethodTester val kvPairs = job invokePrivate generateEnvironmentKVPairs("script-bucket", "prefix-", "key") - - kvPairs should contain (buildKVPair(job.AWS_MAX_ATTEMPTS, job.AWS_MAX_ATTEMPTS_DEFAULT_VALUE)) - kvPairs should contain (buildKVPair(job.AWS_RETRY_MODE, "adaptive")) kvPairs should contain (buildKVPair("BATCH_FILE_TYPE", "script")) kvPairs should contain (buildKVPair("BATCH_FILE_S3_URL", "s3://script-bucket/prefix-key")) } @@ -162,10 +164,118 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi // testing a private method see https://www.scalatest.org/user_guide/using_PrivateMethodTester val kvPairs = job invokePrivate generateEnvironmentKVPairs("script-bucket", "prefix-", "key") - - kvPairs should contain (buildKVPair(job.AWS_MAX_ATTEMPTS, job.AWS_MAX_ATTEMPTS_DEFAULT_VALUE)) - kvPairs should contain (buildKVPair(job.AWS_RETRY_MODE, "adaptive")) kvPairs should contain (buildKVPair("BATCH_FILE_TYPE", 
"script")) kvPairs should contain (buildKVPair("BATCH_FILE_S3_URL", "")) } + + it should "contain expected command script in reconfigured script" in { + val job = generateBasicJob + job.reconfiguredScript should include (script.replace("/cromwell_root", "/tmp/scratch")) + } + + it should "add metadata environment variables to reconfigured script" in { + val job = generateJobWithS3InOut + job.reconfiguredScript should include ("export AWS_METADATA_SERVICE_TIMEOUT=10\n") + job.reconfiguredScript should include ("export AWS_METADATA_SERVICE_NUM_ATTEMPTS=10\n") + } + + it should "add s3 localize with retry function to reconfigured script" in { + val job = generateBasicJob + val retryFunctionText = s""" + |function _s3_localize_with_retry() { + | local s3_path=$$1 + | # destination must be the path to a file and not just the directory you want the file in + | local destination=$$2 + | + | for i in {1..5}; + | do + | if [[ $$s3_path =~ s3://([^/]+)/(.+) ]]; then + | bucket="$${BASH_REMATCH[1]}" + | key="$${BASH_REMATCH[2]}" + | content_length=$$(/usr/local/aws-cli/v2/current/bin/aws s3api head-object --bucket "$$bucket" --key "$$key" --query 'ContentLength') + | else + | echo "$$s3_path is not an S3 path with a bucket and key. aborting" + | exit 1 + | fi + | /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress "$$s3_path" "$$destination" && + | [[ $$(LC_ALL=C ls -dn -- "$$destination" | awk '{print $$5; exit}') -eq "$$content_length" ]] && break || + | echo "attempt $$i to copy $$s3_path failed"; + | + | if [ "$$i" -eq 5 ]; then + | echo "failed to copy $$s3_path after $$i attempts. aborting" + | exit 2 + | fi + | sleep $$((7 * "$$i")) + | done + |} + |""".stripMargin + + job.reconfiguredScript should include (retryFunctionText) + } + + it should "generate postscript with output copy command in reconfigured script" in { + val job = generateJobWithS3InOut + val postscript = + s""" + |{ + |set -e + |echo '*** DELOCALIZING OUTPUTS ***' + | + |/usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress /tmp/scratch/baa s3://bucket/somewhere/baa + | + | + |if [ -f /tmp/scratch/hello-rc.txt ]; then /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress /tmp/scratch/hello-rc.txt ${job.jobPaths.returnCode} ; fi + | + |if [ -f /tmp/scratch/hello-stderr.log ]; then /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress /tmp/scratch/hello-stderr.log ${job.jobPaths.standardPaths.error}; fi + |if [ -f /tmp/scratch/hello-stdout.log ]; then /usr/local/aws-cli/v2/current/bin/aws s3 cp --no-progress /tmp/scratch/hello-stdout.log ${job.jobPaths.standardPaths.output}; fi + | + |echo '*** COMPLETED DELOCALIZATION ***' + |echo '*** EXITING WITH RETURN CODE ***' + |rc=$$(head -n 1 /tmp/scratch/hello-rc.txt) + |echo $$rc + |exit $$rc + |} + |""".stripMargin + job.reconfiguredScript should include (postscript) + } + + it should "generate preamble with input copy command in reconfigured script" in { + val job = generateJobWithS3InOut + val preamble = + s""" + |{ + |set -e + |echo '*** LOCALIZING INPUTS ***' + |if [ ! 
-d /tmp/scratch ]; then mkdir /tmp/scratch && chmod 777 /tmp/scratch; fi + |cd /tmp/scratch + |_s3_localize_with_retry s3://bucket/foo /tmp/scratch/foo + |echo '*** COMPLETED LOCALIZATION ***' + |set +e + |} + |""".stripMargin + + job.reconfiguredScript should include (preamble) + } + + it should "contain AWS Service clients" in { + val job = generateBasicJob + job.batchClient should not be null + job.s3Client should not be null + job.cloudWatchLogsClient should not be null + } + + it should "have correct script prefix" in { + val job = generateBasicJob + job.scriptKeyPrefix should equal("scripts/") + } + + it should "return correct RC code given Batch Job Detail" in { + val containerDetail: ContainerDetail = ContainerDetail.builder().exitCode(0).build + val jobDetail: JobDetail = JobDetail.builder().container(containerDetail).build + val job = generateBasicJob + job.rc(jobDetail) should be (0) + } + + + } diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala index 09e1ee94351..f8c009f95ed 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala @@ -65,7 +65,9 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout false, ContinueOnReturnCodeSet(Set(0)), false, - "my-stuff") + "my-stuff", + 1, + Vector(Map.empty[String, String])) val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"), @@ -76,6 +78,8 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout ContinueOnReturnCodeSet(Set(0)), false, "", + 1, + Vector(Map.empty[String, String]), "local") "AwsBatchRuntimeAttributes" should { @@ -339,6 +343,33 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout val expectedRuntimeAttributes = expectedDefaults.copy(cpu = refineMV[Positive](4)) assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedRuntimeAttributes, workflowOptions) } + + "validate a valid awsBatchRetryAttempts entry" in { + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "awsBatchRetryAttempts" -> WomInteger(9), "scriptBucketName" -> WomString("my-stuff")) + val expectedRuntimeAttributes = expectedDefaults.copy(awsBatchRetryAttempts = 9) + assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedRuntimeAttributes) + } + + "fail to validate with -1 as awsBatchRetryAttempts" in { + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "awsBatchRetryAttempts" -> WomInteger(-1), "scriptBucketName" -> WomString("my-stuff")) + assertAwsBatchRuntimeAttributesFailedCreation(runtimeAttributes, "Expecting awsBatchRetryAttempts runtime attribute value greater than or equal to 0") + } + + "fail to validate with 12 as awsBatchRetryAttempts" in { + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "awsBatchRetryAttempts" -> WomInteger(12), "scriptBucketName" -> WomString("my-stuff")) + assertAwsBatchRuntimeAttributesFailedCreation(runtimeAttributes, "Expecting awsBatchRetryAttempts runtime attribute value lower than or equal to 10") + } + + "fail to validate with a string as awsBatchRetryAttempts" in { + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), 
"awsBatchRetryAttempts" -> WomString("test"), "scriptBucketName" -> WomString("my-stuff")) + assertAwsBatchRuntimeAttributesFailedCreation(runtimeAttributes, "Expecting awsBatchRetryAttempts runtime attribute to be an Integer") + } + + "validate zero as awsBatchRetryAttempts entry" in { + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "awsBatchRetryAttempts" -> WomInteger(0), "scriptBucketName" -> WomString("my-stuff")) + val expectedRuntimeAttributes = expectedDefaults.copy(awsBatchRetryAttempts = 0) + assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes, expectedRuntimeAttributes) + } } private def assertAwsBatchRuntimeAttributesSuccessfulCreation(runtimeAttributes: Map[String, WomValue], diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala index 38545c7e472..682714b225c 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala @@ -61,6 +61,7 @@ object AwsBatchTestConfig { | zones:["us-east-1a", "us-east-1b"] | queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue" | scriptBucketName: "my-bucket" + | awsBatchRetryAttempts: 1 |} | |""".stripMargin @@ -140,6 +141,7 @@ object AwsBatchTestConfigForLocalFS { | zones:["us-east-1a", "us-east-1b"] | queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue" | scriptBucketName: "" + | awsBatchRetryAttempts: 1 |} | |""".stripMargin diff --git a/wdl/transforms/new-base/src/main/scala/wdl/transforms/base/linking/expression/values/EngineFunctionEvaluators.scala b/wdl/transforms/new-base/src/main/scala/wdl/transforms/base/linking/expression/values/EngineFunctionEvaluators.scala index 551dde8086f..2ec759c946c 100644 --- a/wdl/transforms/new-base/src/main/scala/wdl/transforms/base/linking/expression/values/EngineFunctionEvaluators.scala +++ b/wdl/transforms/new-base/src/main/scala/wdl/transforms/base/linking/expression/values/EngineFunctionEvaluators.scala @@ -51,7 +51,7 @@ object EngineFunctionEvaluators { EvaluatedValue(WomSingleFile(ioFunctionSet.pathFunctions.stderr), Seq.empty).validNel } - private val ReadWaitTimeout = 60.seconds + private val ReadWaitTimeout = 300.seconds private def readFile(fileToRead: WomSingleFile, ioFunctionSet: IoFunctionSet, sizeLimit: Int) = { Try(Await.result(ioFunctionSet.readFile(fileToRead.value, Option(sizeLimit), failOnOverflow = true), ReadWaitTimeout)) }