Cache and bookkeep existing branches in remote workflow.
riga committed Dec 15, 2024
1 parent 6f73b18 commit 37592c9
Showing 1 changed file with 71 additions and 50 deletions.
law/workflow/remote.py: 121 changes (71 additions, 50 deletions)
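
Note: in short, instead of computing `output()` in `run()` and snapshotting the initially existing collection keys into a list, the proxy now memoizes the output via `_get_cached_output()` and bookkeeps a set of existing branch numbers via `_get_existing_branches()`, growing it as jobs finish. A minimal, runnable sketch of that cache-and-bookkeep pattern (illustrative names, not law's actual classes):

    # Minimal sketch of the pattern this commit introduces (hypothetical
    # names; the real code lives in law.workflow.remote).
    class WorkflowProxy:
        def __init__(self):
            self._cached_output = None      # memoized output() result
            self._existing_branches = None  # branch numbers with existing outputs

        def output(self):
            # stands in for an expensive target-collection lookup
            return {"collection": [0, 1, 4]}

        def _get_cached_output(self):
            # compute once, reuse afterwards
            if self._cached_output is None:
                self._cached_output = self.output()
            return self._cached_output

        def _get_existing_branches(self, sync=False):
            # build the set lazily, or rebuild on explicit sync
            if self._existing_branches is None:
                sync = True
            if sync:
                self._existing_branches = set(self._get_cached_output()["collection"])
            return self._existing_branches

    proxy = WorkflowProxy()
    print(proxy._get_existing_branches())  # {0, 1, 4}
    proxy._existing_branches |= {2}        # bookkeeping as branches finish
    print(proxy._get_existing_branches())  # {0, 1, 2, 4}
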
@@ -207,19 +207,19 @@ def __init__(self, *args, **kwargs):
         self._job_manager_setup_kwargs = no_value

         # boolean per job num denoting if a job should be / was skipped
-        self.skip_jobs = {}
+        self._skip_jobs = {}

         # retry counts per job num
-        self.job_retries = defaultdict(int)
+        self._job_retries = defaultdict(int)

-        # cached output() return value, set in run()
-        self._outputs = None
+        # cached output() return value
+        self._cached_output = None

         # flag that denotes whether a submission was done before, set in run()
         self._submitted = False

-        # initially existing keys of the "collection" output (= complete branch tasks), set in run()
-        self._initially_existing_branches = []
+        # set of existing branches that is tracked during processing
+        self._existing_branches = None

         # flag denoting if jobs were cancelled or cleaned up (i.e. controlled)
         self._controlled_jobs = False
@@ -349,25 +349,44 @@ def _cleanup_jobs(self):
"""
return isinstance(getattr(self.task, "cleanup_jobs", None), bool) and self.task.cleanup_jobs

def _get_cached_output(self):
if self._cached_output is None:
self._cached_output = self.output()
return self._cached_output

def _get_existing_branches(self, sync=False, collection=None):
if self._existing_branches is None:
sync = True

if sync:
# initialize with set
self._existing_branches = set()
# add initial branches existing in output collection
if collection is None:
collection = self._get_cached_output().get("collection")
if collection is not None:
keys = collection.count(existing=True, keys=True)[1]
self._existing_branches |= set(keys)

return self._existing_branches

def _can_skip_job(self, job_num, branches):
"""
Returns *True* when a job can be potentially skipped, which is the case when all branch
tasks given by *branches* are complete.
"""
if job_num not in self.skip_jobs:
self.skip_jobs[job_num] = all(
(b in self._initially_existing_branches) or self.task.as_branch(b).complete()
for b in branches
)
if job_num not in self._skip_jobs:
existing_branches = self._get_existing_branches()
self._skip_jobs[job_num] = all((b in existing_branches) for b in branches)

# when the job is skipped, ensure that a job data entry exists and set the status
if self.skip_jobs[job_num]:
if self._skip_jobs[job_num]:
if job_num not in self.job_data.jobs:
self.job_data.jobs[job_num] = self.job_data_cls.job_data(branches=branches)
if not self.job_data.jobs[job_num]["status"]:
self.job_data.jobs[job_num]["status"] = self.job_manager.FINISHED

return self.skip_jobs[job_num]
return self._skip_jobs[job_num]

def _get_job_kwargs(self, name):
attr = "{}_job_kwargs_{}".format(self.workflow_type, name)
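
Note: `_can_skip_job` no longer falls back to calling `self.task.as_branch(b).complete()` per branch, which can trigger filesystem or remote lookups; a job is now skippable exactly when all of its branches are in the tracked set, and the decision is cached per job number. A micro-example of the membership logic (made-up job and branch numbers):

    # Made-up numbers; shows the cached set-membership check.
    existing_branches = {0, 1, 2, 5}
    skip_jobs = {}

    def can_skip_job(job_num, branches):
        # decide once per job_num, then reuse the cached answer
        if job_num not in skip_jobs:
            skip_jobs[job_num] = all(b in existing_branches for b in branches)
        return skip_jobs[job_num]

    print(can_skip_job(1, [0, 1]))  # True: all branches already exist
    print(can_skip_job(2, [2, 3]))  # False: branch 3 is missing
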
@@ -537,19 +556,20 @@ def _maximum_resources(cls, resources, n_parallel):
         return dict(zip(keys, merged_counts))

     def process_resources(self, force=False):
-        if self._initial_process_resources is None or force:
-            task = self.task
+        task = self.task

-            job_resources = {}
+        # collect resources over all branches if not just controlling running jobs
+        if (
+            not task.is_controlling_remote_jobs() and
+            (self._initial_process_resources is None or force)
+        ):
             get_job_resources = self._get_task_attribute("job_resources")

             branch_chunks = iter_chunks(task.branch_map.keys(), task.tasks_per_job)
-            for job_num, branches in enumerate(branch_chunks, 1):
-                if self._can_skip_job(job_num, branches):
-                    continue
-                job_resources[job_num] = get_job_resources(job_num, branches)
-
-            self._initial_process_resources = job_resources
+            self._initial_process_resources = {
+                job_num: get_job_resources(job_num, branches)
+                for job_num, branches in enumerate(branch_chunks, 1)
+                if not self._can_skip_job(job_num, branches)
+            }

         if not self._initial_process_resources:
             return {}
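
Note: the rewritten `process_resources` also gains an early guard so that no resources are collected when the task is only controlling already-running jobs. The loop-plus-`continue` becomes a filtered dict comprehension; both forms are equivalent, as this standalone snippet with stand-in data shows:

    # Stand-in data; demonstrates that the comprehension equals the old loop.
    chunks = [(1, [0, 1]), (2, [2, 3]), (3, [4, 5])]
    skipped = {2}
    get_job_resources = lambda job_num, branches: {"cpus": len(branches)}

    # old style: loop with continue
    job_resources = {}
    for job_num, branches in chunks:
        if job_num in skipped:
            continue
        job_resources[job_num] = get_job_resources(job_num, branches)

    # new style: filtered dict comprehension
    assert job_resources == {
        job_num: get_job_resources(job_num, branches)
        for job_num, branches in chunks
        if job_num not in skipped
    }
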
@@ -614,8 +634,9 @@ def dump_job_data(self):
         self.job_data["dashboard_config"] = self.dashboard.get_persistent_config()

         # write the job data to the output file
+        output = self._get_cached_output()
         with self._dump_lock:
-            self._outputs["jobs"].dump(self.job_data, formatter="json", indent=4)
+            output["jobs"].dump(self.job_data, formatter="json", indent=4)

         logger.debug("job data dumped")
@@ -633,29 +654,28 @@ def _run_impl(self):
         performs job cancelling or cleaning, depending on the task parameters.
         """
         task = self.task
-        self._outputs = self.output()
+        output = self._get_cached_output()

         # create the job dashboard interface
         self.dashboard = task.create_job_dashboard() or NoJobDashboard()

         # read job data and reset some values
-        self._submitted = not task.ignore_submission and self._outputs["jobs"].exists()
+        self._submitted = not task.ignore_submission and output["jobs"].exists()
         if self._submitted:
             # load job data and cast job ids
-            self.job_data.update(self._outputs["jobs"].load(formatter="json"))
+            self.job_data.update(output["jobs"].load(formatter="json"))
             for job_data in six.itervalues(self.job_data.jobs):
                 job_data["job_id"] = self.job_manager.cast_job_id(job_data["job_id"])

             # sync other settings
             task.tasks_per_job = self.job_data.tasks_per_job
             self.dashboard.apply_config(self.job_data.dashboard_config)

-        # store the initially complete branches
+        # store initially complete branches
         outputs_existing = False
-        if "collection" in self._outputs:
-            collection = self._outputs["collection"]
-            count, keys = collection.count(keys=True)
-            self._initially_existing_branches = keys
+        if "collection" in output:
+            collection = output["collection"]
+            count = len(self._get_existing_branches(collection=collection))
             outputs_existing = count >= collection._abs_threshold()

         # cancel jobs?
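
Note: in `_run_impl`, the initial key snapshot from `collection.count(keys=True)` is replaced by syncing the branch set against the collection once and taking its length; `outputs_existing` then flips when that count reaches the collection's absolute threshold. A rough sketch of the comparison, assuming `_abs_threshold()` converts a fractional acceptance into an absolute count (assumed semantics, not law's exact implementation):

    # Assumed semantics of _abs_threshold(): fraction -> absolute count.
    def abs_threshold(n_targets, acceptance=1.0):
        return max(1, int(round(acceptance * n_targets)))

    existing_branches = {0, 1, 2}  # three of four outputs already exist
    outputs_existing = len(existing_branches) >= abs_threshold(4, acceptance=0.75)
    print(outputs_existing)  # True
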
@@ -683,7 +703,7 @@

         # ensure the output directory exists
         if not self._submitted:
-            self._outputs["jobs"].parent.touch()
+            output["jobs"].parent.touch()

         try:
             # instantiate the configured job file factory
@@ -1022,8 +1042,8 @@ def poll(self):
         n_jobs = len(self.job_data)

         # track finished and failed jobs
-        finished_jobs = []
-        failed_jobs = []
+        finished_jobs = set()
+        failed_jobs = set()

         # the resources of yet unfinished jobs as claimed initially and reported to the scheduler,
         # and the maximum amount of resources potentially claimed by the jobs
@@ -1074,7 +1094,7 @@ def poll(self):
             if self._can_skip_job(job_num, data["branches"]):
                 data["status"] = self.job_manager.FINISHED
                 data["code"] = 0
-                finished_jobs.append(job_num)
+                finished_jobs.add(job_num)
                 continue

             # mark as active or unknown
@@ -1156,7 +1176,7 @@ def poll(self):

             # when the task picked up an existing submission file, then in the first polling
             # iteration it might happen that a job is finished, but outputs of its tasks are
-            # not existing, e.g. when they were removed externaly and the job id is still known
+            # not existing, e.g. when they were removed externally and the job id is still known
             # to the batch system; in this case, mark it as unknown and to be retried
             if self._submitted and i == 0:
                 is_finished = data["status"] == self.job_manager.FINISHED
@@ -1176,21 +1196,21 @@ def poll(self):
                 time.sleep(check_completeness_delay)

             # store jobs per status and take further actions depending on the status
-            pending_jobs = []
-            running_jobs = []
-            newly_failed_jobs = []
-            retry_jobs = []
+            pending_jobs = set()
+            running_jobs = set()
+            retry_jobs = set()
+            newly_failed_jobs = []  # need to preserve order
             for job_num in active_jobs:
                 data = self.job_data.jobs[job_num]

                 if data["status"] == self.job_manager.PENDING:
-                    pending_jobs.append(job_num)
+                    pending_jobs.add(job_num)
                     task.forward_dashboard_event(self.dashboard, copy.deepcopy(data),
                         "status.pending", job_num)
                     continue

                 if data["status"] == self.job_manager.RUNNING:
-                    running_jobs.append(job_num)
+                    running_jobs.add(job_num)
                     task.forward_dashboard_event(self.dashboard, copy.deepcopy(data),
                         "status.running", job_num)
                     continue
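
Note: the status containers switch to sets for cheap membership tests and the set algebra used later in `poll`, while `newly_failed_jobs` deliberately stays a list: as the added comment states, the order in which failures are recorded matters, and sets do not preserve insertion order. For illustration (made-up job numbers):

    # Lists preserve insertion order, sets do not.
    newly_failed_jobs = []
    for job_num in (9, 3, 7):
        newly_failed_jobs.append(job_num)
    print(newly_failed_jobs)  # [9, 3, 7] - failure order kept
    print(sorted({9, 3, 7}))  # [3, 7, 9] - only a sorted view is deterministic
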
@@ -1201,7 +1221,8 @@ def poll(self):
                        self.task.as_branch(b).complete()
                        for b in data["branches"]
                    ):
-                        finished_jobs.append(job_num)
+                        finished_jobs.add(job_num)
+                        self._existing_branches |= set(data["branches"])
                         self.poll_data.n_active -= 1
                         data["job_id"] = self.job_data.dummy_job_id
                         task.forward_dashboard_event(self.dashboard, copy.deepcopy(data),
@@ -1225,16 +1246,16 @@ def poll(self):
                     self.poll_data.n_active -= 1

                     # retry or ultimately failed?
-                    if self.job_retries[job_num] < task.retries:
-                        self.job_retries[job_num] += 1
+                    if self._job_retries[job_num] < task.retries:
+                        self._job_retries[job_num] += 1
                         self.job_data.attempts.setdefault(job_num, 0)
                         self.job_data.attempts[job_num] += 1
                         data["status"] = self.job_manager.RETRY
-                        retry_jobs.append(job_num)
+                        retry_jobs.add(job_num)
                         task.forward_dashboard_event(self.dashboard, copy.deepcopy(data),
                             "status.retry", job_num)
                     else:
-                        failed_jobs.append(job_num)
+                        failed_jobs.add(job_num)
                         task.forward_dashboard_event(self.dashboard, copy.deepcopy(data),
                             "status.failed", job_num)
                     continue
@@ -1264,7 +1285,7 @@ def poll(self):
             task.publish_progress(100.0 * n_finished / n_jobs)

             # remove resources of finished and failed jobs
-            for job_num in finished_jobs + failed_jobs:
+            for job_num in finished_jobs | failed_jobs:
                 job_resources.pop(job_num, None)
             # check if the maximum possible resources decreased and report to the scheduler
             new_max_resources = self._maximum_resources(job_resources, self.poll_data.n_parallel)
@@ -1301,7 +1322,7 @@ def poll(self):

             # complain when failed
             if failed:
-                failed_nums = [job_num for job_num in failed_jobs if job_num not in retry_jobs]
+                failed_nums = sorted(failed_jobs - retry_jobs)
                 raise Exception(
                     "tolerance exceeded for job(s) {}".format(",".join(map(str, failed_nums))),
                 )
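
Note: with sets in place, the bookkeeping at the end of `poll` turns into set algebra: the list concatenation `finished_jobs + failed_jobs` becomes the union `|`, and the comprehension that filtered out retried jobs becomes the difference `-`, with `sorted()` making the error message deterministic. A compact sketch (invented job numbers):

    # Invented job numbers; mirrors the union/difference used above.
    finished_jobs = {1, 2}
    failed_jobs = {3, 5}
    retry_jobs = {5}

    released = finished_jobs | failed_jobs          # jobs whose resources are freed
    failed_nums = sorted(failed_jobs - retry_jobs)  # ultimately failed, ordered
    print(released)                                 # {1, 2, 3, 5}
    print(",".join(map(str, failed_nums)))          # "3"
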
