Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add gpu count #41

Merged
merged 3 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/src/main/scala/cromwell/backend/backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ object CommonBackendConfigurationAttributes {
"default-runtime-attributes.zones",
"default-runtime-attributes.continueOnReturnCode",
"default-runtime-attributes.cpu",
"default-runtime-attributes.gpuCount",
"default-runtime-attributes.noAddress",
"default-runtime-attributes.docker",
"default-runtime-attributes.queueArn",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ object AwsBatchAttributes {
"awsBatchRetryAttempts",
"awsBatchEvaluateOnExit",
"ulimits",
"gpuCount",
"efsDelocalize",
"efsMakeMD5",
"tagResources",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,17 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
Log.debug(s"Submitting taskId: $taskId, job definition : $definitionArn, script: $batch_script")
Log.info(s"Submitting taskId: $rootworkflowId::$taskId, script: $batch_script")

//provide job environment variables, vcpu and memory
var resourceRequirements: Seq[ResourceRequirement] = Seq(
ResourceRequirement.builder().`type`(ResourceType.VCPU).value(runtimeAttributes.cpu.##.toString).build(),
ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build()
)

if (runtimeAttributes.gpuCount > 0) {
val gpuRequirement = ResourceRequirement.builder().`type`(ResourceType.GPU).value(runtimeAttributes.gpuCount.toString)
resourceRequirements = resourceRequirements :+ gpuRequirement.build()
}

// prepare the job request
var submitJobRequest = SubmitJobRequest.builder()
.jobName(sanitize(jobDescriptor.taskCall.fullyQualifiedName))
Expand All @@ -581,10 +592,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
.environment(
generateEnvironmentKVPairs(runtimeAttributes.scriptS3BucketName, scriptKeyPrefix, scriptKey): _*
)
.resourceRequirements(
ResourceRequirement.builder().`type`(ResourceType.VCPU).value(runtimeAttributes.cpu.##.toString).build(),
ResourceRequirement.builder().`type`(ResourceType.MEMORY).value(runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString).build()
)
.resourceRequirements(resourceRequirements.asJava)
.build()
)
.jobQueue(runtimeAttributes.queueArn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ trait AwsBatchJobDefinitionBuilder {
tagResources
)

// To reuse job definition for gpu and gpu-runs, we will create a job definition that does not gpu requirements
// since aws batch does not allow you to set gpu as 0 when you dont need it. you will always need cpu and memory
Comment on lines +185 to +186
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me know if this is desired - or we should create two separate definition

(ContainerProperties.builder()
.image(context.runtimeAttributes.dockerImage)
.command(packedCommand.asJava)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import com.typesafe.config.{ConfigException, ConfigValueFactory}

import scala.util.matching.Regex
import org.slf4j.{Logger, LoggerFactory}
import wom.RuntimeAttributesKeys.GpuKey

import scala.util.{Failure, Success, Try}
import scala.jdk.CollectionConverters._
Expand All @@ -56,6 +57,7 @@ import scala.jdk.CollectionConverters._
/**
* Attributes that are provided to the job at runtime
* @param cpu number of vCPU
* @param gpuCount number of gpu
* @param zones the aws availability zones to run in
* @param memory memory to allocate
* @param disks a sequence of disk volumes
Expand All @@ -74,6 +76,7 @@ import scala.jdk.CollectionConverters._
* @param tagResources should we tag resources
*/
case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
gpuCount: Int,
zones: Vector[String],
memory: MemorySize,
disks: Seq[AwsBatchVolume],
Expand Down Expand Up @@ -125,6 +128,10 @@ object AwsBatchRuntimeAttributes {
private def cpuValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instance
.withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin)

private def gpuCountValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int] = {
GpuCountValidation(GpuKey).withDefault(GpuCountValidation(GpuKey).configDefaultWomValue(runtimeConfig).getOrElse(WomInteger(0)))
}

private def cpuMinValidation(runtimeConfig: Option[Config]):RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instanceMin
.withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin)

Expand Down Expand Up @@ -222,6 +229,7 @@ object AwsBatchRuntimeAttributes {
def validationsS3backend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
cpuValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
gpuCountValidation(runtimeConfig),
disksValidation(runtimeConfig),
zonesValidation(runtimeConfig),
memoryValidation(runtimeConfig),
Expand All @@ -240,6 +248,7 @@ object AwsBatchRuntimeAttributes {
def validationsLocalBackend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
cpuValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
gpuCountValidation(runtimeConfig),
disksValidation(runtimeConfig),
zonesValidation(runtimeConfig),
memoryValidation(runtimeConfig),
Expand All @@ -264,6 +273,7 @@ object AwsBatchRuntimeAttributes {

def apply(validatedRuntimeAttributes: ValidatedRuntimeAttributes, runtimeAttrsConfig: Option[Config], fileSystem:String): AwsBatchRuntimeAttributes = {
val cpu: Int Refined Positive = RuntimeAttributesValidation.extract(cpuValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val gpuCount: Int = RuntimeAttributesValidation.extract(gpuCountValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val zones: Vector[String] = RuntimeAttributesValidation.extract(ZonesValidation, validatedRuntimeAttributes)
val memory: MemorySize = RuntimeAttributesValidation.extract(memoryValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val disks: Seq[AwsBatchVolume] = RuntimeAttributesValidation.extract(disksValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
Expand All @@ -285,6 +295,7 @@ object AwsBatchRuntimeAttributes {
val tagResources: Boolean = RuntimeAttributesValidation.extract(awsBatchtagResourcesValidation(runtimeAttrsConfig),validatedRuntimeAttributes)
new AwsBatchRuntimeAttributes(
cpu,
gpuCount,
zones,
memory,
disks,
Expand Down Expand Up @@ -479,6 +490,23 @@ object DisksValidation extends RuntimeAttributesValidation[Seq[AwsBatchVolume]]
s"Expecting $key runtime attribute to be a comma separated String or Array[String]"
}

object GpuCountValidation {
def apply(key: String): GpuCountValidation = new GpuCountValidation(key)
}

class GpuCountValidation(key: String) extends IntRuntimeAttributesValidation(key) {
override protected def validateValue: PartialFunction[WomValue, ErrorOr[Int]] = {
case womValue if WomIntegerType.coerceRawValue(womValue).isSuccess =>
WomIntegerType.coerceRawValue(womValue).get match {
case WomInteger(value) =>
if (value.toInt < 0)
s"Expecting $key runtime attribute value greater than or equal to 0".invalidNel
else
value.toInt.validNel
}
}
}

object AwsBatchRetryAttemptsValidation {
def apply(key: String): AwsBatchRetryAttemptsValidation = new AwsBatchRetryAttemptsValidation(key)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,34 @@ Parameter description:
- Type: Integer
- Required: Yes, when `ulimits` is used.

### GPU support

Tasks can request GPU by setting `gpuCount` in the task runtime attribute. For instance:
```
task gpu_queue_task {
input {
...
}

command <<<
...
>>>
output {}

runtime {
queueArn: "arn:aws:batch:us-west-2:12345678910:job-queue/quekx-gpu-queue"
docker: "xxxx"
maxRetries: 1
cpu: "1"
gpuCount: 1
memory: "2 GB"
}
}
```
the gpuCount value will be passed to AWS Batch as part of [resourceRequirements](https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html#ContainerProperties-resourceRequirements).
You will need to use this feature in conjunction with a aws queue that has GPU instances (see [compute-environment](/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/DEPLOY.md#compute-environment) for more inforamtion)


### Call Caching with ECR private

AWS ECR is a private container registry, for which access can be regulated using IAM. Call caching is possible by setting up the following configuration:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ import org.scalatest.PrivateMethodTester
import org.scalatest.flatspec.AnyFlatSpecLike
import org.scalatest.matchers.should.Matchers
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider
import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, RetryAction, RetryStrategy}
import software.amazon.awssdk.services.batch.model.{ContainerDetail, EvaluateOnExit, JobDetail, KeyValuePair, ResourceRequirement, RetryAction, RetryStrategy}
import spray.json.{JsObject, JsString}
import wdl4s.parser.MemoryUnit
import wom.format.MemorySize
import wom.graph.CommandCallNode

import scala.jdk.javaapi.CollectionConverters

class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with PrivateMethodTester {
import AwsBatchTestConfig._

Expand Down Expand Up @@ -109,6 +111,7 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
val cpu: Int Refined Positive = 2
val runtimeAttributes: AwsBatchRuntimeAttributes = new AwsBatchRuntimeAttributes(
cpu = cpu,
gpuCount = 0,
zones = Vector("us-east-1"),
memory = MemorySize(2.0, MemoryUnit.GB),
disks = Seq.empty,
Expand Down Expand Up @@ -408,10 +411,8 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
).build()

val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime))
val jobDefinitionName = jobDefinition.name
val expected = jobDefinition.retryStrategy
expected should equal (builder)
jobDefinitionName should equal ("cromwell_ubuntu_latest_656d5a7e7cd016d2360b27bc5ee75018d91a777a")
}

it should "use RetryStrategy evaluateOnExit should be case insensitive" in {
Expand All @@ -424,9 +425,21 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
).build()

val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime))
val jobDefinitionName = jobDefinition.name
val expected = jobDefinition.retryStrategy
expected should equal(builder)
jobDefinitionName should equal("cromwell_ubuntu_latest_66a335d761780e64e6b154339c5f1db2f0783f96")
}

it should "GPU is not set at job definition even if provided" in {
val runtime = runtimeAttributes.copy(
gpuCount = 1
)

val expected = List(
ResourceRequirement.builder().`type`("VCPU").value(s"$cpu").build(),
ResourceRequirement.builder().`type`("MEMORY").value("2048").build()
)

val jobDefinition = StandardAwsBatchJobDefinitionBuilder.build(batchJobDefintion.copy(runtimeAttributes = runtime))
val actual = jobDefinition.containerProperties.resourceRequirements
expected should equal(CollectionConverters.asScala(actual).toSeq)}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
)))
}

val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), 0, Vector("us-east-1a", "us-east-1b"),

MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
Expand All @@ -74,7 +74,7 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
false
)

val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), 0, Vector("us-east-1a", "us-east-1b"),

MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
Expand Down
Loading