Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Batch submit task #41

Merged
merged 11 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- `batch_submit` task - [#41](https://github.com/PrefectHQ/prefect-aws/issues/41)

### Changed

### Deprecated
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ nav:
- Credentials: credentials.md
- S3: s3.md
- Secrets Manager: secrets_manager.md
- Batch: batch.md
77 changes: 77 additions & 0 deletions prefect_aws/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tasks for interacting with AWS Batch"""

from functools import partial
from typing import Any, Dict, Optional

from anyio import to_thread
from prefect import get_run_logger, task

from prefect_aws.credentials import AwsCredentials


@task
async def batch_submit(
job_name: str,
job_queue: str,
job_definition: str,
aws_credentials: AwsCredentials,
**batch_kwargs: Optional[Dict[str, Any]],
):
"""
Submit a job to the AWS Batch job service.

Args:
job_name: The AWS batch job name.
job_definition: The AWS batch job definition.
job_queue: Name of the AWS batch job queue.
aws_credentials: Credentials to use for authentication with AWS.
batch_kwargs: Additional keyword arguments to pass to the boto3
`submit_job` function. See the documentation for
[submit_job](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html#Batch.Client.submit_job)
for more details.

Returns:
The id corresponding to the job.

Example:
Submits a job to batch.

```python
from prefect import flow
from prefect_aws import AwsCredentials
from prefect_aws.batch import batch_submit


@flow
def example_batch_submit_flow():
aws_credentials = AwsCredentials(
aws_access_key_id="acccess_key_id",
aws_secret_access_key="secret_access_key"
)
job_id = batch_submit(
"job_name",
"job_definition",
"job_queue",
aws_credentials
)
return job_id

example_batch_submit_flow()
```

""" # noqa
logger = get_run_logger()
logger.info("Preparing to submit %s job to %s job queue", job_name, job_queue)

batch_kwargs = batch_kwargs or {}
batch_client = aws_credentials.get_boto3_session().client("batch")

submit_job = partial(
batch_client.submit_job,
jobName=job_name,
jobQueue=job_queue,
jobDefinition=job_definition,
**batch_kwargs,
)
response = await to_thread.run_sync(submit_job)
return response["jobId"]
7 changes: 0 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import pytest
from prefect.utilities.testing import prefect_test_harness

from prefect_aws import AwsCredentials
from prefect_aws.client_parameters import AwsClientParameters


@pytest.fixture(scope="session", autouse=True)
def prefect_db():
with prefect_test_harness():
yield

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove this fixture?

Copy link
Contributor Author

@ahuang11 ahuang11 Jun 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was crashing tests earlier (e.g. from prefect.utilities.testing import prefect_test_harness), but now I think it's not used since the tests are still passing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

asyncio: mode=auto
collected 58 items

tests/test_batch.py .                                                                                                                                                                              [  1%]
tests/test_client_parameters.py ....                                                                                                                                                               [  8%]
tests/test_s3.py ...............                                                                                                                                                                   [ 34%]
tests/test_secrets_manager.py ......................................                                                                                                                               [100%]

===================================================================================== 58 passed in 103.01s (0:01:43) =====================================================================================

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the comments in #27 where it was added, it's a nice thing to have to spin up a temporary DB per test. I think we should leave it in and correct the import path (from prefect.test.utilities import prefect_test_harness).


@pytest.fixture
def aws_credentials():
return AwsCredentials(
Expand Down
89 changes: 89 additions & 0 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from uuid import UUID

import boto3
import pytest
from moto import mock_batch, mock_iam, mock_s3
from prefect import flow

from prefect_aws.batch import batch_submit


@pytest.fixture(scope="function")
def s3_client(aws_credentials):
with mock_s3():
yield boto3.client("s3", region_name="us-east-1")


@pytest.fixture(scope="function")
def batch_client(aws_credentials):
with mock_batch():
yield boto3.client("batch", region_name="us-east-1")


@pytest.fixture(scope="function")
def iam_client(aws_credentials):
with mock_iam():
yield boto3.client("iam", region_name="us-east-1")


@pytest.fixture()
def job_queue_arn(iam_client, batch_client):
iam_role = iam_client.create_role(
RoleName="test_batch_client",
AssumeRolePolicyDocument="string",
)
iam_arn = iam_role.get("Role").get("Arn")

compute_environment = batch_client.create_compute_environment(
computeEnvironmentName="test_batch_ce", type="UNMANAGED", serviceRole=iam_arn
)

compute_environment_arn = compute_environment.get("computeEnvironmentArn")

created_queue = batch_client.create_job_queue(
jobQueueName="test_batch_queue",
state="ENABLED",
priority=1,
computeEnvironmentOrder=[
{"order": 1, "computeEnvironment": compute_environment_arn},
],
)
job_queue_arn = created_queue.get("jobQueueArn")
return job_queue_arn


@pytest.fixture
def job_definition_arn(batch_client):
job_definition = batch_client.register_job_definition(
jobDefinitionName="test_batch_jobdef",
type="container",
containerProperties={
"image": "busybox",
"vcpus": 1,
"memory": 128,
"command": ["sleep", "2"],
},
)
job_definition_arn = job_definition.get("jobDefinitionArn")
return job_definition_arn


def test_batch_submit(job_queue_arn, job_definition_arn, aws_credentials):
@flow
def test_flow():
return batch_submit(
"batch_test_job",
job_queue_arn,
job_definition_arn,
aws_credentials,
)

flow_state = test_flow()
assert flow_state.is_completed

job_id = flow_state.result().result()
try:
UUID(str(job_id))
assert True, f"{job_id} is a valid UUID"
except ValueError:
assert False, f"{job_id} is not a valid UUID"