Re-Attach to running alignment chunks on retry #372

Merged · 13 commits · Jul 10, 2024
lib/idseq_utils/idseq_utils/batch_run_helpers.py — 95 changes: 74 additions & 21 deletions
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import os
@@ -9,7 +10,7 @@
 from os import listdir
 from multiprocessing import Pool
 from subprocess import run
-from typing import Dict, List
+from typing import Dict, List, Optional
 from urllib.parse import urlparse

 from idseq_utils.diamond_scatter import blastx_join
@@ -19,6 +20,11 @@
 from botocore.exceptions import ClientError
 from botocore.config import Config

+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S',
+)
 log = logging.getLogger(__name__)

 MAX_CHUNKS_IN_FLIGHT = 30  # TODO: remove this constant, currently does nothing since we have at most 30 index chunks
@@ -83,25 +89,62 @@ def _get_job_status(job_id, use_batch_api=False):
         raise e


+class BatchJobCache:
+    """
+    BatchJobCache saves job IDs so the coordinator can re-attach to running batch jobs if the coordinator fails.
+
+    The output should always be the same if the inputs are the same; however, we also incorporate batch_args
+    into the cache key because a retry on spot vs. on-demand lands on a different batch queue.
+    """
+    def __init__(self, bucket: str, prefix: str, inputs: Dict[str, str]):
+        self.bucket = bucket
+        self.prefix = prefix
+        self.inputs = inputs
+
+    def _key(self, batch_args: Dict) -> str:
+        hasher = hashlib.sha256()
+        cache_dict = {"inputs": self.inputs, "batch_args": batch_args}
+        hasher.update(json.dumps(cache_dict, sort_keys=True).encode())
+        return os.path.join(self.prefix, hasher.hexdigest())
+
+    def get(self, batch_args: Dict) -> Optional[str]:
+        try:
+            resp = _s3_client.get_object(Bucket=self.bucket, Key=self._key(batch_args))
+            return resp["Body"].read().decode()
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "NoSuchKey":
+                return None
+            else:
+                raise e
+
+    def put(self, batch_args: Dict, job_id: str):
+        _s3_client.put_object(
+            Bucket=self.bucket,
+            Key=self._key(batch_args),
+            Body=job_id.encode(),
+            Tagging="AlignmentCoordination=True",
+        )
+
+
 def _run_batch_job(
     job_name: str,
     job_queue: str,
     job_definition: str,
     environment: Dict[str, str],
     retries: int,
+    cache: BatchJobCache,
 ):
-    response = _batch_client.submit_job(
-        jobName=job_name,
-        jobQueue=job_queue,
-        jobDefinition=job_definition,
-        containerOverrides={
+    submit_args = {
+        "jobName": job_name,
+        "jobQueue": job_queue,
+        "jobDefinition": job_definition,
+        "containerOverrides": {
             "environment": [{"name": k, "value": v} for k, v in environment.items()],
             "memory": 130816,
             "vcpus": 24,
         },
-        retryStrategy={"attempts": retries},
-    )
-    job_id = response["jobId"]
+        "retryStrategy": {"attempts": retries},
+    }

     def _log_status(status: str):
         level = logging.INFO if status != "FAILED" else logging.ERROR
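
The cache key is a pure function of the chunk inputs and the Batch submit arguments, so a restarted coordinator recomputes exactly the same S3 key. A minimal sketch of that derivation (the `inputs`, queue, and prefix values here are placeholders, not values from this PR):

```python
import hashlib
import json
import os

# Placeholder values for illustration only.
inputs = {"query_0": "s3://example-bucket/sample/chunk-0.fa"}
batch_args = {"jobQueue": "example-spot-queue", "jobDefinition": "diamond:12"}

# Same derivation as BatchJobCache._key: hash the sorted JSON of inputs plus
# batch_args, so a retry on a different queue produces a different key.
hasher = hashlib.sha256()
hasher.update(json.dumps({"inputs": inputs, "batch_args": batch_args}, sort_keys=True).encode())
print(os.path.join("sample/diamond-chunks/batch_job_cache", hasher.hexdigest()))
```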
@@ -119,7 +162,14 @@ def _log_status(status: str):
             ),
         )

-    _log_status("SUBMITTED")
+    job_id = cache.get(submit_args)
+    if job_id:
+        log.info(f"reattach to batch job: {job_id}")
+    else:
+        response = _batch_client.submit_job(**submit_args)
+        job_id = response["jobId"]
+        cache.put(submit_args, job_id)
+        _log_status("SUBMITTED")

     delay = 60 + random.randint(
         -60 // 2, 60 // 2
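
The jittered delay spreads status polling across the many chunk coordinators so they do not hit the Batch API in lockstep. A self-contained sketch of the pattern (the `get_status` callable and terminal states are assumptions; the real loop lives outside this hunk and raises `BatchJobFailed` on failure):

```python
import random
import time
from typing import Callable

def poll_until_done(get_status: Callable[[], str], base: int = 60) -> str:
    # 60s +/- 30s of jitter, mirroring `60 + random.randint(-60 // 2, 60 // 2)`.
    delay = base + random.randint(-base // 2, base // 2)
    while True:
        status = get_status()
        if status in ("SUCCEEDED", "FAILED"):
            return status
        time.sleep(delay)
```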
@@ -194,26 +244,19 @@ def _job_queue(provisioning_model: str):
     input_bucket, input_key = _bucket_and_key(wdl_input_uri)

     wdl_output_uri = os.path.join(chunk_dir, f"{chunk_id}-output.json")
-    output_bucket, output_key = _bucket_and_key(wdl_output_uri)

     wdl_workflow_uri = f"s3://idseq-workflows/{aligner}-{aligner_wdl_version}/{aligner}.wdl"

-    # if this job fails we don't want to re-run chunks that have already been processed
-    # the presence of the output file means the chunk has already been processed
-    try:
-        _s3_client.head_object(Bucket=output_bucket, Key=output_key)
-        log.info(f"skipping chunk, output already exists: {wdl_output_uri}")
-        return
-    except ClientError as e:
-        # raise the error if it is anything other than "not found"
-        if e.response["Error"]["Code"] != "404":
-            raise e
+    cache_prefix_uri = os.path.join(chunk_dir, "batch_job_cache/")
+    cache_bucket, cache_prefix = _bucket_and_key(cache_prefix_uri)
+    cache = BatchJobCache(cache_bucket, cache_prefix, inputs)

     _s3_client.put_object(
         Bucket=input_bucket,
         Key=input_key,
         Body=json.dumps(inputs).encode(),
         ContentType="application/json",
+        Tagging="AlignmentCoordination=True",
     )

     environment = {
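
`_bucket_and_key` is called throughout but defined outside the hunks shown. A plausible reconstruction (an assumption, not code from this PR), given the `urlparse` import at the top of the module:

```python
from typing import Tuple
from urllib.parse import urlparse

def _bucket_and_key(s3_uri: str) -> Tuple[str, str]:
    # "s3://bucket/path/to/key" -> ("bucket", "path/to/key")
    parsed = urlparse(s3_uri)
    return parsed.netloc, parsed.path.lstrip("/")
```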
@@ -231,6 +274,7 @@ def _job_queue(provisioning_model: str):
             job_definition=job_definition,
             environment=environment,
             retries=2,
+            cache=cache,
         )
     except BatchJobFailed:
         _run_batch_job(
@@ -239,6 +283,7 @@
             job_definition=job_definition,
             environment=environment,
             retries=1,
+            cache=cache,
         )
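
Because `batch_args` feed the cache key, the on-demand retry after a spot failure maps to a fresh key and submits a new job instead of re-attaching to the failed spot job. A stripped-down sketch of the fallback shape (the exception class and `run_job` callable are stand-ins for illustration):

```python
class BatchJobFailed(Exception):
    pass

def submit_with_fallback(run_job, spot_queue: str, on_demand_queue: str) -> None:
    try:
        # Spot first: cheaper, but instances can be reclaimed, so allow retries.
        run_job(job_queue=spot_queue, retries=2)
    except BatchJobFailed:
        # Fall back to on demand for one final attempt.
        run_job(job_queue=on_demand_queue, retries=1)
```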


@@ -263,6 +308,7 @@ def run_alignment(
 ):
     bucket, prefix = _bucket_and_key(db_path)
     chunk_dir = os.path.join(input_dir, f"{aligner}-chunks")
+    _, chunk_prefix = _bucket_and_key(chunk_dir)
     chunks = (
         [
             input_dir,
@@ -281,9 +327,16 @@
     run(["s3parcp", "--recursive", chunk_dir, "chunks"], check=True)
     if os.path.exists(os.path.join("chunks", "cache")):
         shutil.rmtree(os.path.join("chunks", "cache"))
+    if os.path.exists(os.path.join("chunks", "batch_job_cache")):
+        shutil.rmtree(os.path.join("chunks", "batch_job_cache"))
     for fn in listdir("chunks"):
         if fn.endswith("json"):
             os.remove(os.path.join("chunks", fn))
+            _s3_client.put_object_tagging(
+                Bucket=bucket,
+                Key=os.path.join(chunk_prefix, fn),
+                Tagging={"TagSet": [{"Key": "AlignmentCoordination", "Value": "True"}]},
+            )
     if aligner == "diamond":
         blastx_join("chunks", result_path, aligner_args, *queries)
     else:
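
Every coordination object this PR writes (cache entries, WDL inputs, and now the chunk JSON files) carries the `AlignmentCoordination=True` tag, presumably so a bucket lifecycle rule can expire them. A sketch of reading that tag back with boto3 (the bucket and key are placeholders):

```python
import boto3

s3 = boto3.client("s3")

resp = s3.get_object_tagging(
    Bucket="example-bucket",
    Key="sample/diamond-chunks/0-output.json",
)
tags = {t["Key"]: t["Value"] for t in resp["TagSet"]}
assert tags.get("AlignmentCoordination") == "True"
```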