Refactored parameter macro settings. (valayDave#60)
valayDave committed Jul 28, 2022
1 parent a3a4950 commit 0ffc813
Showing 2 changed files with 64 additions and 50 deletions.
57 changes: 20 additions & 37 deletions metaflow/plugins/airflow/airflow.py
@@ -31,34 +31,19 @@
 from .exception import AirflowException
 from .sensors import SUPPORTED_SENSORS
 from .airflow_utils import (
-    AIRFLOW_TASK_ID,
-    RUN_HASH_ID_LEN,
-    RUN_ID_PREFIX,
     TASK_ID_XCOM_KEY,
     AirflowTask,
     Workflow,
+    AIRFLOW_MACROS,
 )
 from metaflow import current
 
 AIRFLOW_DEPLOY_TEMPLATE_FILE = os.path.join(os.path.dirname(__file__), "dag.py")
 
 
 class Airflow(object):
 
-    parameter_macro = "{{ params | json_dump }}"
-    task_id = AIRFLOW_TASK_ID
-    task_id_arg = "--task-id %s" % task_id
-
-    # Airflow run_ids are of the form: "manual__2022-03-15T01:26:41.186781+00:00".
-    # Such run-ids break `metaflow.util.decompress_list`; this is why we hash the run id.
-    run_id = (
-        "%s-$(echo -n {{ run_id }}-{{ dag_run.dag_id }} | md5sum | awk '{print $1}' | awk '{print substr ($0, 0, %s)}')"
-        % (RUN_ID_PREFIX, str(RUN_HASH_ID_LEN))
-    )
-    # We use `echo -n` because `echo` appends a trailing newline, which would change the hash; we want the same hash value when the run id is recomputed in Python.
-    run_id_arg = "--run-id %s" % run_id
-    attempt = "{{ task_instance.try_number - 1 }}"
-
     def __init__(
         self,
         name,
@@ -226,12 +211,12 @@ def _make_input_path_compressed(
     ):
         """
         This function compresses the input paths, and it deliberately doesn't use
-        `metaflow.util.compress_list` under the hood. The reason is that `self.run_id` is a complicated macro string
+        `metaflow.util.compress_list` under the hood. The reason is that `AIRFLOW_MACROS.RUN_ID_SHELL` is a complicated macro string
         that doesn't behave nicely with `metaflow.util.decompress_list`, since the `decompress_list`
         function expects a string free of delimiter characters, and the run-id string contains them.
         Hence we create a custom compressed string via `_make_input_path_compressed` instead of `compress_list`.
         """
-        return "%s:" % (self.run_id) + ",".join(
+        return "%s:" % (AIRFLOW_MACROS.RUN_ID_SHELL) + ",".join(
             self._make_input_path(s, only_task_id=True) for s in step_names
         )
 
@@ -248,7 +233,7 @@ def _make_input_path(self, step_name, only_task_id=False):
         if only_task_id:
             return task_id_string
 
-        return "%s%s" % (self.run_id, task_id_string)
+        return "%s%s" % (AIRFLOW_MACROS.RUN_ID_SHELL, task_id_string)
 
     def _to_job(self, node):
         """
@@ -276,7 +261,7 @@ def _to_job(self, node):
         # parameters.
 
         if len(self.parameters):
-            env["METAFLOW_PARAMETERS"] = self.parameter_macro
+            env["METAFLOW_PARAMETERS"] = AIRFLOW_MACROS.PARAMETER
             input_paths = None
         else:
             # If it is not the start node then we check if there are many paths
@@ -328,11 +313,9 @@ def _to_job(self, node):
         k8s = Kubernetes(self.flow_datastore, self.metadata, self.environment)
         user = util.get_username()
 
-        airflow_task_id = AIRFLOW_TASK_ID
-        mf_run_id = (
-            "%s-{{ [run_id, dag_run.dag_id] | run_id_creator }}" % RUN_ID_PREFIX
-        )  # run_id_creator is added via the `user_defined_filters`
-        attempt = "{{ task_instance.try_number - 1 }}"
+        airflow_task_id = AIRFLOW_MACROS.TASK_ID
+        mf_run_id = AIRFLOW_MACROS.RUN_ID
+        attempt = AIRFLOW_MACROS.ATTEMPT
         labels = {
             "app": "metaflow",
             "app.kubernetes.io/name": "metaflow-task",
@@ -358,8 +341,8 @@ def _to_job(self, node):
             "METAFLOW_CARD_S3ROOT": DATASTORE_CARD_S3ROOT,
             "METAFLOW_RUN_ID": mf_run_id,
             "METAFLOW_AIRFLOW_TASK_ID": airflow_task_id,
-            "METAFLOW_AIRFLOW_DAG_RUN_ID": "{{run_id}}",
-            "METAFLOW_AIRFLOW_JOB_ID": "{{ti.job_id}}",
+            "METAFLOW_AIRFLOW_DAG_RUN_ID": AIRFLOW_MACROS.AIRFLOW_RUN_ID,
+            "METAFLOW_AIRFLOW_JOB_ID": AIRFLOW_MACROS.AIRFLOW_JOB_ID,
             "METAFLOW_ATTEMPT_NUMBER": attempt,
         }
         env.update(additional_mf_variables)
@@ -416,10 +399,10 @@ def _to_job(self, node):
             node_selector=k8s_deco.attributes["node_selector"],
             cmds=k8s._command(
                 self.flow.name,
-                self.run_id,
+                AIRFLOW_MACROS.RUN_ID_SHELL,
                 node.name,
-                self.task_id,
-                self.attempt,
+                AIRFLOW_MACROS.TASK_ID,
+                AIRFLOW_MACROS.ATTEMPT,
                 code_package_url=self.code_package_url,
                 step_cmds=self._step_cli(
                     node, input_paths, self.code_package_url, user_code_retries
@@ -492,7 +475,7 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
 
         if node.name == "start":
             # We need a separate unique ID for the special _parameters task
-            task_id_params = "%s-params" % self.task_id
+            task_id_params = "%s-params" % AIRFLOW_MACROS.TASK_ID
             # Export user-defined parameters into runtime environment
             param_file = "".join(
                 random.choice(string.ascii_lowercase) for _ in range(10)
@@ -509,7 +492,7 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
                 + top_level
                 + [
                     "init",
-                    self.run_id_arg,
+                    "--run-id %s" % AIRFLOW_MACROS.RUN_ID_SHELL,
                     "--task-id %s" % task_id_params,
                 ]
             )
@@ -525,7 +508,7 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
                 # Dump the parameters task
                 "dump",
                 "--max-value-size=0",
-                "%s/_parameters/%s" % (self.run_id, task_id_params),
+                "%s/_parameters/%s" % (AIRFLOW_MACROS.RUN_ID_SHELL, task_id_params),
             ]
             cmd = "if ! %s >/dev/null 2>/dev/null; then %s && %s; fi" % (
                 " ".join(exists),
@@ -534,14 +517,14 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
             )
             cmds.append(cmd)
             # set input paths for parameters
-            paths = "%s/_parameters/%s" % (self.run_id, task_id_params)
+            paths = "%s/_parameters/%s" % (AIRFLOW_MACROS.RUN_ID_SHELL, task_id_params)
 
         step = [
             "step",
             node.name,
-            self.run_id_arg,
-            self.task_id_arg,
-            "--retry-count %s" % self.attempt,
+            "--run-id %s" % AIRFLOW_MACROS.RUN_ID_SHELL,
+            "--task-id %s" % AIRFLOW_MACROS.TASK_ID,
+            "--retry-count %s" % AIRFLOW_MACROS.ATTEMPT,
             "--max-user-code-retries %d" % user_code_retries,
             "--input-paths %s" % paths,
         ]
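The commit keeps two run-id macros on purpose: `RUN_ID_SHELL` hashes the run id at the shell level, while the `run_id_creator` Jinja filter does the same in Python, and the `echo -n` comment exists because both sides must produce the same truncated md5. Below is a minimal sketch of that equivalence, assuming `RUN_HASH_ID_LEN = 12` (the real constant lives in airflow_utils.py) and using illustrative helper names that are not part of the codebase:

import hashlib
import subprocess

RUN_ID_PREFIX = "airflow"
RUN_HASH_ID_LEN = 12  # assumed for illustration; see airflow_utils.py for the real value


def run_id_python(airflow_run_id, dag_id):
    # Mirrors the `run_id_creator` Jinja filter: truncated md5 of "<run_id>-<dag_id>".
    digest = hashlib.md5(("%s-%s" % (airflow_run_id, dag_id)).encode("utf-8")).hexdigest()
    return "%s-%s" % (RUN_ID_PREFIX, digest[:RUN_HASH_ID_LEN])


def run_id_shell(airflow_run_id, dag_id):
    # Mirrors RUN_ID_SHELL: `echo -n` suppresses the trailing newline that plain
    # `echo` would append, which would otherwise change the md5 digest.
    cmd = "echo -n '%s-%s' | md5sum | awk '{print $1}'" % (airflow_run_id, dag_id)
    digest = subprocess.check_output(["bash", "-c", cmd]).decode().strip()
    return "%s-%s" % (RUN_ID_PREFIX, digest[:RUN_HASH_ID_LEN])


run_id = "manual__2022-03-15T01:26:41.186781+00:00"
print(run_id_python(run_id, "my_flow_dag"))
print(run_id_shell(run_id, "my_flow_dag"))  # same value as the line above

This sketch truncates on the Python side in both helpers to sidestep awk's one-based `substr` indexing; the production macro does the truncation inside awk with `substr ($0, 0, %s)`.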
57 changes: 44 additions & 13 deletions metaflow/plugins/airflow/airflow_utils.py
@@ -23,17 +23,38 @@ class AirflowSensorNotFound(Exception):
 TASK_ID_HASH_LEN = 8
 RUN_ID_PREFIX = "airflow"
 
-# AIRFLOW_TASK_ID will work for linear/branched workflows.
-# ti.task_id is the step name in metaflow code.
-# AIRFLOW_TASK_ID uses a jinja filter called `task_id_creator` which helps
-# concatenate the string using a `/`. Since the run-id changes on every run while the step name
-# stays the same, the task id changes from run to run. Since airflow doesn't encourage dynamic rewriting of dags,
-# we can rename steps in a foreach with indexes (e.g. `stepname-$index`) to create those steps.
-# Hence: foreaches will require some special form of plumbing.
-# https://stackoverflow.com/questions/62962386/can-an-airflow-task-dynamically-generate-a-dag-at-runtime
-AIRFLOW_TASK_ID = (
-    "%s-{{ [run_id, ti.task_id, dag_run.dag_id ] | task_id_creator }}" % RUN_ID_PREFIX
-)
 
+class AIRFLOW_MACROS:
+    # run_id_creator is added via the `user_defined_filters`
+    RUN_ID = "%s-{{ [run_id, dag_run.dag_id] | run_id_creator }}" % RUN_ID_PREFIX
+    PARAMETER = "{{ params | json_dump }}"
+
+    # AIRFLOW_MACROS.TASK_ID will work for linear/branched workflows.
+    # ti.task_id is the step name in metaflow code.
+    # AIRFLOW_MACROS.TASK_ID uses a jinja filter called `task_id_creator` which helps
+    # concatenate the string using a `/`. Since the run-id changes on every run while the step name
+    # stays the same, the task id changes from run to run. Since airflow doesn't encourage dynamic rewriting of dags,
+    # we can rename steps in a foreach with indexes (e.g. `stepname-$index`) to create those steps.
+    # Hence: foreaches will require some special form of plumbing.
+    # https://stackoverflow.com/questions/62962386/can-an-airflow-task-dynamically-generate-a-dag-at-runtime
+    TASK_ID = (
+        "%s-{{ [run_id, ti.task_id, dag_run.dag_id ] | task_id_creator }}"
+        % RUN_ID_PREFIX
+    )
+
+    # Airflow run_ids are of the form: "manual__2022-03-15T01:26:41.186781+00:00".
+    # Such run-ids break `metaflow.util.decompress_list`; this is why we hash the run id.
+    # We use `echo -n` because `echo` appends a trailing newline, which would change the hash; we want the same hash value when the run id is recomputed in Python.
+    RUN_ID_SHELL = (
+        "%s-$(echo -n {{ run_id }}-{{ dag_run.dag_id }} | md5sum | awk '{print $1}' | awk '{print substr ($0, 0, %s)}')"
+        % (RUN_ID_PREFIX, str(RUN_HASH_ID_LEN))
+    )
+
+    ATTEMPT = "{{ task_instance.try_number - 1 }}"
+
+    AIRFLOW_RUN_ID = "{{ run_id }}"
+
+    AIRFLOW_JOB_ID = "{{ ti.job_id }}"
+
 
 class SensorNames:
@@ -46,6 +67,16 @@ def get_supported_sensors(cls):
         return list(cls.__dict__.values())
 
 
+def run_id_creator(val):
+    # joins `[run_id, dag_id]` of the airflow dag run and truncates the md5 hash.
+    return hashlib.md5("-".join(val).encode("utf-8")).hexdigest()[:RUN_HASH_ID_LEN]
+
+
+def task_id_creator(val):
+    # joins `[run_id, task_id, dag_id]` of the airflow dag run and truncates the md5 hash.
+    return hashlib.md5("-".join(val).encode("utf-8")).hexdigest()[:TASK_ID_HASH_LEN]
+
+
 def id_creator(val, hash_len):
     # joins the given components and truncates the md5 hash to `hash_len` characters.
     return hashlib.md5("-".join(val).encode("utf-8")).hexdigest()[:hash_len]
@@ -88,9 +119,9 @@ class AirflowDAGArgs(object):
 
     # Reference for user_defined_filters: https://stackoverflow.com/a/70175317
     filters = dict(
-        task_id_creator=lambda v: id_creator(v, TASK_ID_HASH_LEN),
+        task_id_creator=lambda v: task_id_creator(v),
         json_dump=lambda val: json_dump(val),
-        run_id_creator=lambda val: id_creator(val, RUN_HASH_ID_LEN),
+        run_id_creator=lambda val: run_id_creator(val),
     )
 
     def __init__(self, **kwargs):
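Since the new hash helpers are wired into Jinja via `user_defined_filters`, the macros only resolve when Airflow renders the templated fields. Here is a self-contained sketch of that rendering using plain jinja2; the `_TaskInstance` and `_DagRun` classes below are stand-ins for Airflow's template context, not real Airflow APIs:

import hashlib

from jinja2 import Environment

TASK_ID_HASH_LEN = 8
RUN_ID_PREFIX = "airflow"


def task_id_creator(val):
    # joins the components and truncates the md5 hash, as in airflow_utils.py
    return hashlib.md5("-".join(val).encode("utf-8")).hexdigest()[:TASK_ID_HASH_LEN]


env = Environment()
# Airflow registers these on the DAG via `user_defined_filters`; here we attach
# the filter to a bare jinja2 environment instead.
env.filters["task_id_creator"] = task_id_creator

template = env.from_string(
    "%s-{{ [run_id, ti.task_id, dag_run.dag_id ] | task_id_creator }}" % RUN_ID_PREFIX
)


class _TaskInstance:  # stand-in for Airflow's TaskInstance
    task_id = "start"  # ti.task_id maps to the metaflow step name


class _DagRun:  # stand-in for Airflow's DagRun
    dag_id = "MyFlow"


print(
    template.render(
        run_id="manual__2022-03-15T01:26:41.186781+00:00",
        ti=_TaskInstance,
        dag_run=_DagRun,
    )
)
# prints something like "airflow-1f2e3d4c": the prefix plus an 8-char truncated md5

Because the hash covers `[run_id, ti.task_id, dag_run.dag_id]`, the rendered task id is stable within one dag run but unique across runs and steps, which is exactly the behavior the `AIRFLOW_MACROS.TASK_ID` comment describes.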
