diff --git a/backend/src/main/scala/cromwell/backend/backend.scala b/backend/src/main/scala/cromwell/backend/backend.scala index a8bb4baaf2f..3c5c33ab992 100644 --- a/backend/src/main/scala/cromwell/backend/backend.scala +++ b/backend/src/main/scala/cromwell/backend/backend.scala @@ -136,6 +136,7 @@ object CommonBackendConfigurationAttributes { "default-runtime-attributes.zones", "default-runtime-attributes.continueOnReturnCode", "default-runtime-attributes.cpu", + "default-runtime-attributes.gpuCount", "default-runtime-attributes.noAddress", "default-runtime-attributes.docker", "default-runtime-attributes.queueArn", diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala index 10e2dddfb98..38f026eadcf 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchAttributes.scala @@ -90,6 +90,7 @@ object AwsBatchAttributes { "awsBatchRetryAttempts", "awsBatchEvaluateOnExit", "ulimits", + "gpuCount", "efsDelocalize", "efsMakeMD5", "tagResources", diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala index 51c5b8080f6..4cfa08e82a5 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala @@ -571,6 +571,17 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL Log.debug(s"Submitting taskId: $taskId, job definition : $definitionArn, script: $batch_script") Log.info(s"Submitting taskId: $rootworkflowId::$taskId, script: $batch_script") + //provide job environment variables, vcpu and memory + var resourceRequirements: Seq[ResourceRequirement] = Seq( + ResourceRequirement.builder().`type`(ResourceType.VCPU).value(runtimeAttributes.cpu.##.toString).build(), + ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build() + ) + + if (runtimeAttributes.gpuCount > 0) { + val gpuRequirement = ResourceRequirement.builder().`type`(ResourceType.GPU).value(runtimeAttributes.gpuCount.toString) + resourceRequirements = resourceRequirements :+ gpuRequirement.build() + } + // prepare the job request var submitJobRequest = SubmitJobRequest.builder() .jobName(sanitize(jobDescriptor.taskCall.fullyQualifiedName)) @@ -581,10 +592,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL .environment( generateEnvironmentKVPairs(runtimeAttributes.scriptS3BucketName, scriptKeyPrefix, scriptKey): _* ) - .resourceRequirements( - ResourceRequirement.builder().`type`(ResourceType.VCPU).value(runtimeAttributes.cpu.##.toString).build(), - ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build() - ) + .resourceRequirements(resourceRequirements.asJava) .build() ) .jobQueue(runtimeAttributes.queueArn) diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala index 30148acef87..f55d3a933f7 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala @@ -182,6 +182,8 @@ trait AwsBatchJobDefinitionBuilder { tagResources ) + // To reuse job definition for gpu and gpu-runs, we will create a job definition that does not gpu requirements + // since aws batch does not allow you to set gpu as 0 when you dont need it. you will always need cpu and memory (ContainerProperties.builder() .image(context.runtimeAttributes.dockerImage) .command(packedCommand.asJava) diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala index fc854c1039f..456355e412c 100755 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala @@ -48,6 +48,7 @@ import com.typesafe.config.{ConfigException, ConfigValueFactory} import scala.util.matching.Regex import org.slf4j.{Logger, LoggerFactory} +import wom.RuntimeAttributesKeys.GpuKey import scala.util.{Failure, Success, Try} import scala.jdk.CollectionConverters._ @@ -56,6 +57,7 @@ import scala.jdk.CollectionConverters._ /** * Attributes that are provided to the job at runtime * @param cpu number of vCPU + * @param gpuCount number of gpu * @param zones the aws availability zones to run in * @param memory memory to allocate * @param disks a sequence of disk volumes @@ -74,6 +76,7 @@ import scala.jdk.CollectionConverters._ * @param tagResources should we tag resources */ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive, + gpuCount: Int, zones: Vector[String], memory: MemorySize, disks: Seq[AwsBatchVolume], @@ -125,6 +128,10 @@ object AwsBatchRuntimeAttributes { private def cpuValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instance .withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin) + private def gpuCountValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int] = { + GpuCountValidation(GpuKey).withDefault(GpuCountValidation(GpuKey).configDefaultWomValue(runtimeConfig).getOrElse(WomInteger(0))) + } + private def cpuMinValidation(runtimeConfig: Option[Config]):RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instanceMin .withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin) @@ -222,6 +229,7 @@ object AwsBatchRuntimeAttributes { def validationsS3backend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation( cpuValidation(runtimeConfig), cpuMinValidation(runtimeConfig), + gpuCountValidation(runtimeConfig), disksValidation(runtimeConfig), zonesValidation(runtimeConfig), memoryValidation(runtimeConfig), @@ -240,6 +248,7 @@ object AwsBatchRuntimeAttributes { def validationsLocalBackend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation( cpuValidation(runtimeConfig), cpuMinValidation(runtimeConfig), + gpuCountValidation(runtimeConfig), disksValidation(runtimeConfig), zonesValidation(runtimeConfig), memoryValidation(runtimeConfig), @@ -264,6 +273,7 @@ object AwsBatchRuntimeAttributes { def apply(validatedRuntimeAttributes: ValidatedRuntimeAttributes, runtimeAttrsConfig: Option[Config], fileSystem:String): AwsBatchRuntimeAttributes = { val cpu: Int Refined Positive = RuntimeAttributesValidation.extract(cpuValidation(runtimeAttrsConfig), validatedRuntimeAttributes) + val gpuCount: Int = RuntimeAttributesValidation.extract(gpuCountValidation(runtimeAttrsConfig), validatedRuntimeAttributes) val zones: Vector[String] = RuntimeAttributesValidation.extract(ZonesValidation, validatedRuntimeAttributes) val memory: MemorySize = RuntimeAttributesValidation.extract(memoryValidation(runtimeAttrsConfig), validatedRuntimeAttributes) val disks: Seq[AwsBatchVolume] = RuntimeAttributesValidation.extract(disksValidation(runtimeAttrsConfig), validatedRuntimeAttributes) @@ -285,6 +295,7 @@ object AwsBatchRuntimeAttributes { val tagResources: Boolean = RuntimeAttributesValidation.extract(awsBatchtagResourcesValidation(runtimeAttrsConfig),validatedRuntimeAttributes) new AwsBatchRuntimeAttributes( cpu, + gpuCount, zones, memory, disks, @@ -479,6 +490,23 @@ object DisksValidation extends RuntimeAttributesValidation[Seq[AwsBatchVolume]] s"Expecting $key runtime attribute to be a comma separated String or Array[String]" } +object GpuCountValidation { + def apply(key: String): GpuCountValidation = new GpuCountValidation(key) +} + +class GpuCountValidation(key: String) extends IntRuntimeAttributesValidation(key) { + override protected def validateValue: PartialFunction[WomValue, ErrorOr[Int]] = { + case womValue if WomIntegerType.coerceRawValue(womValue).isSuccess => + WomIntegerType.coerceRawValue(womValue).get match { + case WomInteger(value) => + if (value.toInt < 0) + s"Expecting $key runtime attribute value greater than or equal to 0".invalidNel + else + value.toInt.validNel + } + } +} + object AwsBatchRetryAttemptsValidation { def apply(key: String): AwsBatchRetryAttemptsValidation = new AwsBatchRetryAttemptsValidation(key) } diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md index 2c673503c7b..4b1fd62da2d 100644 --- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md +++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/README.md @@ -141,6 +141,34 @@ Parameter description: - Type: Integer - Required: Yes, when `ulimits` is used. +### GPU support + +Tasks can request GPU by setting `gpuCount` in the task runtime attribute. For instance: +``` +task gpu_queue_task { + input { + ... + } + + command <<< + ... + >>> + output {} + + runtime { + queueArn: "arn:aws:batch:us-west-2:12345678910:job-queue/quekx-gpu-queue" + docker: "xxxx" + maxRetries: 1 + cpu: "1" + gpuCount: 1 + memory: "2 GB" + } +} +``` +the gpuCount value will be passed to AWS Batch as part of [resourceRequirements](https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#ContainerProperties-resourceRequirements). +You will need to use this feature in conjunction with a aws queue that has GPU instances (see [compute-environment](/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/DEPLOY.md#compute-environment) for more inforamtion) + + ### Call Caching with ECR private AWS ECR is a private container registry, for which access can be regulated using IAM. Call caching is possible by setting up the following configuration: diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala index 0c202aa8946..61914626ed7 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala @@ -46,12 +46,14 @@ import org.scalatest.PrivateMethodTester import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider -import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, RetryAction, RetryStrategy} +import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, ResourceRequirement, RetryAction, RetryStrategy} import spray.json.{JsObject, JsString} import wdl4s.parser.MemoryUnit import wom.format.MemorySize import wom.graph.CommandCallNode +import scala.jdk.javaapi.CollectionConverters + class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with PrivateMethodTester { import AwsBatchTestConfig._ @@ -109,6 +111,7 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi val cpu: Int Refined Positive = 2 val runtimeAttributes: AwsBatchRuntimeAttributes = new AwsBatchRuntimeAttributes( cpu = cpu, + gpuCount = 0, zones = Vector("us-east-1"), memory = MemorySize(2.0, MemoryUnit.GB), disks = Seq.empty, @@ -408,10 +411,8 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi ).build() val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime)) - val jobDefinitionName = jobDefinition.name val expected = jobDefinition.retryStrategy expected should equal (builder) - jobDefinitionName should equal ("cromwell_ubuntu_latest_656d5a7e7cd016d2360b27bc5ee75018d91a777a") } it should "use RetryStrategy evaluateOnExit should be case insensitive" in { @@ -424,9 +425,21 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi ).build() val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime)) - val jobDefinitionName = jobDefinition.name val expected = jobDefinition.retryStrategy expected should equal(builder) - jobDefinitionName should equal("cromwell_ubuntu_latest_66a335d761780e64e6b154339c5f1db2f0783f96") } + + it should "GPU is not set at job definition even if provided" in { + val runtime = runtimeAttributes.copy( + gpuCount = 1 + ) + + val expected = List( + ResourceRequirement.builder().`type`("VCPU").value(s"$cpu").build(), + ResourceRequirement.builder().`type`("MEMORY").value("2048").build() + ) + + val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime)) + val actual = jobDefinition.containerProperties.resourceRequirements + expected should equal(CollectionConverters.asScala(actual).toSeq)} } diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala index 944577d229d..a9411b574b5 100644 --- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala +++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala @@ -58,7 +58,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout ))) } - val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"), + val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), 0, Vector("us-east-1a", "us-east-1b"), MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()), "ubuntu:latest", @@ -74,7 +74,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout false ) - val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"), + val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), 0, Vector("us-east-1a", "us-east-1b"), MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()), "ubuntu:latest",