Commit
setting retry logic.
valayDave committed Mar 19, 2022
1 parent a697b56 commit e2a1e50
Showing 3 changed files with 58 additions and 28 deletions.
50 changes: 28 additions & 22 deletions metaflow/plugins/airflow/airflow_compiler.py
@@ -8,6 +8,7 @@
from metaflow import R
import sys
from metaflow.util import compress_list, dict_to_cli_options, to_pascalcase
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
import os
from metaflow.mflog import capture_output_to_mflog
import random
@@ -36,9 +37,9 @@

class Airflow(object):

# todo : consistency guarantee across retries.
# i.e. task_id doesn't change when a new retry is done.
# Check if task_id also works.
# {{ ti.job_id }} doesn't provide the guarantees we need for `Task` ids.
# {{ ti.job_id }} changes when a retry is done.

task_id = "arf-{{ ti.job_id }}"
task_id_arg = "--task-id %s" % task_id
# Airflow run_ids are of the form : "manual__2022-03-15T01:26:41.186781+00:00"
@@ -107,8 +108,10 @@ def _k8s_job(self, node, input_paths, env):
# todo : check for retry
# Since we attach the kubernetes decorator at the CLI, there will be exactly one per step.
k8s_deco = [deco for deco in node.decorators if deco.name == "kubernetes"][0]

user_code_retries, total_retries = self._get_retries(node)
retry_delay = self._get_retry_delay(node)
runtime_limit = get_run_time_limit_for_task(node.decorators)

return create_k8s_args(
self.flow_datastore,
self.metadata,
@@ -130,7 +133,9 @@ def _k8s_job(self, node, input_paths, env):
gpu=k8s_deco.attributes["gpu"],
disk=k8s_deco.attributes["disk"],
memory=k8s_deco.attributes["memory"],
run_time_limit=None, # todo fix
retries=total_retries,
run_time_limit=timedelta(seconds=runtime_limit),
retry_delay=retry_delay,
env=env,
user=util.get_username(),
)
@@ -147,6 +152,13 @@ def _get_retries(self, node):

return max_user_code_retries, max_user_code_retries + max_error_retries

def _get_retry_delay(self, node):
retry_decos = [deco for deco in node.decorators if deco.name == "retry"]
if len(retry_decos) > 0:
retry_mins = retry_decos[0].attributes["minutes_between_retries"]
return timedelta(minutes=int(retry_mins))
return None

def _process_parameters(self):
# Copied from metaflow.plugins.aws.step_functions.step_functions
parameters = []
@@ -322,8 +334,12 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
top_opts_dict.update(deco.get_top_level_options())

top_opts = list(dict_to_cli_options(top_opts_dict))
join_in_foreach = node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach"
any_previous_node_is_foreach = any(self.graph[n].type == "foreach" for n in node.in_funcs)
join_in_foreach = (
node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach"
)
any_previous_node_is_foreach = any(
self.graph[n].type == "foreach" for n in node.in_funcs
)

top_level = top_opts + [
"--quiet",
@@ -488,23 +504,13 @@ def _create_airflow_file(self, json_dag):
def _create_defaults(self):
return {
"owner": get_username(),
# If set on a task, the task won't run in the current DAG run if its previous run failed.
"depends_on_past": False,
"email": [] if self.email is None else [self.email],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=5),
# 'queue': 'bash_queue',
# 'pool': 'backfill',
# 'priority_weight': 10,
# 'end_date': datetime(2016, 1, 1),
# 'wait_for_downstream': False,
# 'dag': dag,
# 'sla': timedelta(hours=2),
# 'execution_timeout': timedelta(seconds=300),
# 'on_failure_callback': some_function,
# 'on_success_callback': some_other_function,
# 'on_retry_callback': another_function,
# 'sla_miss_callback': yet_another_function,
# 'trigger_rule': 'all_success'
"retries": 0,
"execution_timeout": timedelta(days=5),
# check https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/models/baseoperator/index.html?highlight=retry_delay#airflow.models.baseoperator.BaseOperatorMeta
"retry_delay": timedelta(seconds=5),
}
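
With these defaults, a step that carries no @retry decorator now gets retries: 0, a five-second retry_delay, and a five-day execution_timeout, while a step with @retry overrides them through _get_retries and _get_retry_delay above. A minimal sketch of that mapping, assuming illustrative decorator values (attribute names follow the diff; the numbers are made up):

from datetime import timedelta

# Illustrative only: how a step's @retry settings become the Airflow-facing
# retry arguments wired up in this commit. The numbers below are made up.
max_user_code_retries = 3      # e.g. @retry(times=3)
max_error_retries = 1          # platform-error retries added in _get_retries (value assumed)
minutes_between_retries = 10   # e.g. @retry(minutes_between_retries=10)

total_retries = max_user_code_retries + max_error_retries  # passed as retries=total_retries
retry_delay = timedelta(minutes=minutes_between_retries)   # what _get_retry_delay returns

print(total_retries, retry_delay)  # 4 0:10:00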
27 changes: 23 additions & 4 deletions metaflow/plugins/airflow/airflow_utils.py
@@ -36,13 +36,17 @@ def hasher(my_value):


class AirflowDAGArgs(object):
# _arg_types maps the type of each key that needs to be parsed.
# None of the values in this dictionary are actually used; they only
# document the expected type of each key.
_arg_types = {
"dag_id": "asdf",
"description": "asdfasf",
"schedule_interval": "*/2 * * * *",
"start_date": datetime.now(),
"catchup": False,
"tags": [],
"max_retry_delay": "",
"dagrun_timeout": timedelta(minutes=60 * 4),
"default_args": {
"owner": "some_username",
@@ -51,9 +55,10 @@ class AirflowDAGArgs(object):
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(minutes=5),
"queue": "bash_queue",
"pool": "backfill",
# Todo : find defaults
"retry_delay": timedelta(seconds=10),
"queue": "bash_queue", # which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues.
"pool": "backfill", # the slot pool this task should run in, slot pools are a way to limit concurrency for certain tasks
"priority_weight": 10,
"wait_for_downstream": False,
"sla": timedelta(hours=2),
@@ -146,7 +151,9 @@ def generate_rfc1123_name(flow_name, step_name):
def set_k8s_operator_args(flow_name, step_name, operator_args):
from kubernetes import client

task_id = "arf-{{ ti.job_id }}"
task_id = (
"arf-{{ ti.job_id }}" # Todo : find a way to switch this with something else.
)
run_id = "arf-{{ run_id | hash }}" # hash is added via the `user_defined_filters`
attempt = "{{ task_instance.try_number - 1 }}"
# Set dynamic env variables like run-id, task-id etc from here.
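
For reference, the templated ids above are resolved by Airflow's Jinja engine per task instance, and the hash filter is supplied through the DAG's user_defined_filters (the hasher helper earlier in this file). A rough, self-contained illustration using plain jinja2, with a stand-in filter and made-up context values:

from jinja2 import Environment

# Stand-in for the user-defined `hash` filter; the real hasher above may differ.
def hasher(value):
    return str(abs(hash(value)) % 10**8)

env = Environment()
env.filters["hash"] = hasher

# Illustrative context; at runtime Airflow supplies ti, run_id, task_instance.
task_id = env.from_string("arf-{{ ti.job_id }}").render(ti={"job_id": 1234})
run_id = env.from_string("arf-{{ run_id | hash }}").render(
    run_id="manual__2022-03-15T01:26:41.186781+00:00"
)
print(task_id, run_id)  # arf-1234 arf-<8-digit hash>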
@@ -201,6 +208,8 @@ def set_k8s_operator_args(flow_name, step_name, operator_args):
"memory": operator_args.get("memory", "2000M"),
}
), # kubernetes.client.models.v1_resource_requirements.V1ResourceRequirements
"retries": operator_args.get("retries", 0), # Base operator command
"retry_exponential_backoff": False, # todo : should this be a arg we allow on CLI.
"affinity": None, # kubernetes.client.models.v1_affinity.V1Affinity
"config_file": None,
"node_selectors": {}, # todo : Find difference between "node_selectors" / "node_selector"
@@ -231,6 +240,16 @@ def set_k8s_operator_args(flow_name, step_name, operator_args):
"configmaps": None, # todo : find out what this will do ?
}
args["labels"].update(labels)
if operator_args.get("execution_timeout", None):
args["execution_timeout"] = timedelta(**operator_args.get("execution_timeout"))
if operator_args.get("retry_delay", None):
args["retry_delay"] = timedelta(**operator_args.get("retry_delay"))
return args


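Note how the time values travel between the two layers: create_k8s_args in compute/k8s.py (next file) passes execution_timeout and retry_delay as plain keyword dicts, presumably so the operator arguments stay JSON-serialisable, and set_k8s_operator_args above rebuilds real timedelta objects with timedelta(**...). A rough sketch of that round trip, with illustrative values:

from datetime import timedelta

# Illustrative only: mimics the dict -> timedelta rebuild done in
# set_k8s_operator_args; the input values are made up.
operator_args = {
    "execution_timeout": dict(seconds=timedelta(days=5).total_seconds()),
    "retry_delay": dict(seconds=timedelta(minutes=10).total_seconds()),
}

args = {}
if operator_args.get("execution_timeout"):
    args["execution_timeout"] = timedelta(**operator_args["execution_timeout"])
if operator_args.get("retry_delay"):
    args["retry_delay"] = timedelta(**operator_args["retry_delay"])

print(args)  # {'execution_timeout': datetime.timedelta(days=5), 'retry_delay': datetime.timedelta(seconds=600)}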
9 changes: 7 additions & 2 deletions metaflow/plugins/airflow/compute/k8s.py
@@ -1,3 +1,4 @@
from datetime import timedelta
import json
from metaflow.plugins.aws.eks.kubernetes import (
Kubernetes,
@@ -38,7 +39,9 @@ def create_k8s_args(
gpu=None,
disk=None,
memory=None,
run_time_limit=None,
run_time_limit=timedelta(days=5),
retries=None,
retry_delay=None,
env={},
user=None,
):
@@ -91,7 +94,9 @@
cpu=cpu,
memory=memory,
disk=disk,
timeout_in_seconds=run_time_limit,
execution_timeout=dict(seconds=run_time_limit.total_seconds()),
retry_delay=dict(seconds=retry_delay.total_seconds()) if retry_delay else None,
retries=retries,
env_vars=[dict(name=k, value=v) for k, v in env.items()],
labels=labels,
is_delete_operator_pod=True,
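On the calling side, the changed create_k8s_args signature defaults run_time_limit to timedelta(days=5) and forwards the new retries/retry_delay arguments in that dict form, leaving retry_delay as None when the step has no @retry decorator. A short sketch of the resulting keyword plumbing, with made-up values and the surrounding Metaflow objects omitted:

from datetime import timedelta

# Illustrative only: the keyword shapes create_k8s_args hands to the operator args.
run_time_limit = timedelta(seconds=3600)  # e.g. from get_run_time_limit_for_task
retry_delay = None                        # step has no @retry decorator
retries = 2

operator_kwargs = dict(
    execution_timeout=dict(seconds=run_time_limit.total_seconds()),
    retry_delay=dict(seconds=retry_delay.total_seconds()) if retry_delay else None,
    retries=retries,
)
print(operator_kwargs)
# {'execution_timeout': {'seconds': 3600.0}, 'retry_delay': None, 'retries': 2}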
