Skip to content

Commit

Permalink
add gpu count (#41)
Browse files Browse the repository at this point in the history
* add gpu count

* fix typo

---------

Co-authored-by: quekx <quekx@gene.com>
Co-authored-by: Henrique Ribeiro <henriqueribeiro@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 22, 2023
1 parent 6eb71bb commit b391c0f
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 11 deletions.
1 change: 1 addition & 0 deletions backend/src/main/scala/cromwell/backend/backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ object CommonBackendConfigurationAttributes {
"default-runtime-attributes.zones",
"default-runtime-attributes.continueOnReturnCode",
"default-runtime-attributes.cpu",
"default-runtime-attributes.gpuCount",
"default-runtime-attributes.noAddress",
"default-runtime-attributes.docker",
"default-runtime-attributes.queueArn",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ object AwsBatchAttributes {
"awsBatchRetryAttempts",
"awsBatchEvaluateOnExit",
"ulimits",
"gpuCount",
"efsDelocalize",
"efsMakeMD5",
"tagResources",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,17 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
Log.debug(s"Submitting taskId: $taskId, job definition : $definitionArn, script: $batch_script")
Log.info(s"Submitting taskId: $rootworkflowId::$taskId, script: $batch_script")

//provide job environment variables, vcpu and memory
var resourceRequirements: Seq[ResourceRequirement] = Seq(
ResourceRequirement.builder().`type`(ResourceType.VCPU).value(runtimeAttributes.cpu.##.toString).build(),
ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build()
)

if (runtimeAttributes.gpuCount > 0) {
val gpuRequirement = ResourceRequirement.builder().`type`(ResourceType.GPU).value(runtimeAttributes.gpuCount.toString)
resourceRequirements = resourceRequirements :+ gpuRequirement.build()
}

// prepare the job request
var submitJobRequest = SubmitJobRequest.builder()
.jobName(sanitize(jobDescriptor.taskCall.fullyQualifiedName))
Expand All @@ -581,10 +592,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
.environment(
generateEnvironmentKVPairs(runtimeAttributes.scriptS3BucketName, scriptKeyPrefix, scriptKey): _*
)
.resourceRequirements(
ResourceRequirement.builder().`type`(ResourceType.VCPU).value(runtimeAttributes.cpu.##.toString).build(),
ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build()
)
.resourceRequirements(resourceRequirements.asJava)
.build()
)
.jobQueue(runtimeAttributes.queueArn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ trait AwsBatchJobDefinitionBuilder {
tagResources
)

// To reuse job definition for gpu and gpu-runs, we will create a job definition that does not gpu requirements
// since aws batch does not allow you to set gpu as 0 when you dont need it. you will always need cpu and memory
(ContainerProperties.builder()
.image(context.runtimeAttributes.dockerImage)
.command(packedCommand.asJava)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import com.typesafe.config.{ConfigException, ConfigValueFactory}

import scala.util.matching.Regex
import org.slf4j.{Logger, LoggerFactory}
import wom.RuntimeAttributesKeys.GpuKey

import scala.util.{Failure, Success, Try}
import scala.jdk.CollectionConverters._
Expand All @@ -56,6 +57,7 @@ import scala.jdk.CollectionConverters._
/**
* Attributes that are provided to the job at runtime
* @param cpu number of vCPU
* @param gpuCount number of gpu
* @param zones the aws availability zones to run in
* @param memory memory to allocate
* @param disks a sequence of disk volumes
Expand All @@ -74,6 +76,7 @@ import scala.jdk.CollectionConverters._
* @param tagResources should we tag resources
*/
case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
gpuCount: Int,
zones: Vector[String],
memory: MemorySize,
disks: Seq[AwsBatchVolume],
Expand Down Expand Up @@ -125,6 +128,10 @@ object AwsBatchRuntimeAttributes {
private def cpuValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instance
.withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin)

private def gpuCountValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int] = {
GpuCountValidation(GpuKey).withDefault(GpuCountValidation(GpuKey).configDefaultWomValue(runtimeConfig).getOrElse(WomInteger(0)))
}

private def cpuMinValidation(runtimeConfig: Option[Config]):RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instanceMin
.withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin)

Expand Down Expand Up @@ -222,6 +229,7 @@ object AwsBatchRuntimeAttributes {
def validationsS3backend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
cpuValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
gpuCountValidation(runtimeConfig),
disksValidation(runtimeConfig),
zonesValidation(runtimeConfig),
memoryValidation(runtimeConfig),
Expand All @@ -240,6 +248,7 @@ object AwsBatchRuntimeAttributes {
def validationsLocalBackend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
cpuValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
gpuCountValidation(runtimeConfig),
disksValidation(runtimeConfig),
zonesValidation(runtimeConfig),
memoryValidation(runtimeConfig),
Expand All @@ -264,6 +273,7 @@ object AwsBatchRuntimeAttributes {

def apply(validatedRuntimeAttributes: ValidatedRuntimeAttributes, runtimeAttrsConfig: Option[Config], fileSystem:String): AwsBatchRuntimeAttributes = {
val cpu: Int Refined Positive = RuntimeAttributesValidation.extract(cpuValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val gpuCount: Int = RuntimeAttributesValidation.extract(gpuCountValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val zones: Vector[String] = RuntimeAttributesValidation.extract(ZonesValidation, validatedRuntimeAttributes)
val memory: MemorySize = RuntimeAttributesValidation.extract(memoryValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val disks: Seq[AwsBatchVolume] = RuntimeAttributesValidation.extract(disksValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
Expand All @@ -285,6 +295,7 @@ object AwsBatchRuntimeAttributes {
val tagResources: Boolean = RuntimeAttributesValidation.extract(awsBatchtagResourcesValidation(runtimeAttrsConfig),validatedRuntimeAttributes)
new AwsBatchRuntimeAttributes(
cpu,
gpuCount,
zones,
memory,
disks,
Expand Down Expand Up @@ -479,6 +490,23 @@ object DisksValidation extends RuntimeAttributesValidation[Seq[AwsBatchVolume]]
s"Expecting $key runtime attribute to be a comma separated String or Array[String]"
}

object GpuCountValidation {
def apply(key: String): GpuCountValidation = new GpuCountValidation(key)
}

class GpuCountValidation(key: String) extends IntRuntimeAttributesValidation(key) {
override protected def validateValue: PartialFunction[WomValue, ErrorOr[Int]] = {
case womValue if WomIntegerType.coerceRawValue(womValue).isSuccess =>
WomIntegerType.coerceRawValue(womValue).get match {
case WomInteger(value) =>
if (value.toInt < 0)
s"Expecting $key runtime attribute value greater than or equal to 0".invalidNel
else
value.toInt.validNel
}
}
}

object AwsBatchRetryAttemptsValidation {
def apply(key: String): AwsBatchRetryAttemptsValidation = new AwsBatchRetryAttemptsValidation(key)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,34 @@ Parameter description:
- Type: Integer
- Required: Yes, when `ulimits` is used.

### GPU support

Tasks can request GPU by setting `gpuCount` in the task runtime attribute. For instance:
```
task gpu_queue_task {
input {
...
}
command <<<
...
>>>
output {}
runtime {
queueArn: "arn:aws:batch:us-west-2:12345678910:job-queue/quekx-gpu-queue"
docker: "xxxx"
maxRetries: 1
cpu: "1"
gpuCount: 1
memory: "2 GB"
}
}
```
the gpuCount value will be passed to AWS Batch as part of [resourceRequirements](https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#ContainerProperties-resourceRequirements).
You will need to use this feature in conjunction with a aws queue that has GPU instances (see [compute-environment](/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/DEPLOY.md#compute-environment) for more inforamtion)


### Call Caching with ECR private

AWS ECR is a private container registry, for which access can be regulated using IAM. Call caching is possible by setting up the following configuration:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ import org.scalatest.PrivateMethodTester
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.matchers.should.Matchers
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider
import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, RetryAction, RetryStrategy}
import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, ResourceRequirement, RetryAction, RetryStrategy}
import spray.json.{JsObject, JsString}
import wdl4s.parser.MemoryUnit
import wom.format.MemorySize
import wom.graph.CommandCallNode

import scala.jdk.javaapi.CollectionConverters

class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with PrivateMethodTester {
import AwsBatchTestConfig._

Expand Down Expand Up @@ -109,6 +111,7 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
val cpu: Int Refined Positive = 2
val runtimeAttributes: AwsBatchRuntimeAttributes = new AwsBatchRuntimeAttributes(
cpu = cpu,
gpuCount = 0,
zones = Vector("us-east-1"),
memory = MemorySize(2.0, MemoryUnit.GB),
disks = Seq.empty,
Expand Down Expand Up @@ -408,10 +411,8 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
).build()

val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime))
val jobDefinitionName = jobDefinition.name
val expected = jobDefinition.retryStrategy
expected should equal (builder)
jobDefinitionName should equal ("cromwell_ubuntu_latest_656d5a7e7cd016d2360b27bc5ee75018d91a777a")
}

it should "use RetryStrategy evaluateOnExit should be case insensitive" in {
Expand All @@ -424,9 +425,21 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
).build()

val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime))
val jobDefinitionName = jobDefinition.name
val expected = jobDefinition.retryStrategy
expected should equal(builder)
jobDefinitionName should equal("cromwell_ubuntu_latest_66a335d761780e64e6b154339c5f1db2f0783f96")
}

it should "GPU is not set at job definition even if provided" in {
val runtime = runtimeAttributes.copy(
gpuCount = 1
)

val expected = List(
ResourceRequirement.builder().`type`("VCPU").value(s"$cpu").build(),
ResourceRequirement.builder().`type`("MEMORY").value("2048").build()
)

val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime))
val actual = jobDefinition.containerProperties.resourceRequirements
expected should equal(CollectionConverters.asScala(actual).toSeq)}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
)))
}

val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), 0, Vector("us-east-1a", "us-east-1b"),

MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
Expand All @@ -74,7 +74,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
false
)

val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), 0, Vector("us-east-1a", "us-east-1b"),

MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
Expand Down

0 comments on commit b391c0f

Please sign in to comment.