Addressing Final comments. (#57)
- Added dag-run timeout.
- Moved Airflow-related scheduling checks into the decorator.
- Auto-name sensors when no name is provided.
- Added annotations to k8s operators.
- fix: argument serialization for `DAG` arguments (method names refactored, e.g. `to_dict` became `serialize`).
- Annotation bug fix.
- Set `workflow-timeout` only for scheduled DAGs (a usage sketch follows the changed-files summary below).
valayDave committed Jul 29, 2022
1 parent d3ad82d commit 8e68090
Showing 5 changed files with 98 additions and 107 deletions.
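A hedged sketch of the headline user-facing change: the re-enabled `--workflow-timeout` option is forwarded to the DAG compiler and only turns into a `dagrun_timeout` when the DAG actually has a schedule. The values below are illustrative, and the invocation form `python flow.py airflow create dag.py --workflow-timeout 3600` is an assumption rather than something shown in this diff.

```python
# Illustrative sketch only: mirrors the guard added in airflow.py; concrete values are made up.
workflow_timeout = 3600          # seconds, from the --workflow-timeout CLI option
schedule_interval = "0 * * * *"  # set via @schedule / @airflow_schedule_interval, otherwise None

airflow_dag_args = {}
if workflow_timeout is not None and schedule_interval is not None:
    # Stored as a plain dict so it survives JSON serialization; AirflowDAGArgs.deserialize
    # rebuilds it as a timedelta when the generated DAG file is loaded by Airflow.
    airflow_dag_args["dagrun_timeout"] = dict(seconds=workflow_timeout)
```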
22 changes: 22 additions & 0 deletions metaflow/plugins/airflow/airflow.py
@@ -37,6 +37,7 @@
AirflowTask,
Workflow,
)
from metaflow import current

AIRFLOW_DEPLOY_TEMPLATE_FILE = os.path.join(os.path.dirname(__file__), "dag.py")

@@ -77,6 +78,7 @@ def __init__(
worker_pool=None,
description=None,
file_path=None,
workflow_timeout=None,
is_paused_upon_creation=True,
):
self.name = name
@@ -100,6 +102,7 @@ def __init__(
_, self.graph_structure = self.graph.output_steps()
self.worker_pool = worker_pool
self.is_paused_upon_creation = is_paused_upon_creation
self.workflow_timeout = workflow_timeout
self._set_scheduling_interval()

def _set_scheduling_interval(self):
@@ -398,6 +401,20 @@ def _to_job(self, node):
)
)

annotations = {
"metaflow/owner": self.username,
"metaflow/user": self.username,
"metaflow/flow_name": self.flow.name,
}
if current.get("project_name"):
annotations.update(
{
"metaflow/project_name": current.project_name,
"metaflow/branch_name": current.branch_name,
"metaflow/project_flow_name": current.project_flow_name,
}
)

k8s_operator_args = dict(
# like argo workflows we use step_name as name of container
name=node.name,
@@ -415,6 +432,7 @@ def _to_job(self, node):
node, input_paths, self.code_package_url, user_code_retries
),
),
annotations=annotations,
image=k8s_deco.attributes["image"],
resources=resources,
execution_timeout=dict(seconds=runtime_limit),
@@ -594,6 +612,10 @@ def _visit(node, workflow, exit_node=None):
)
airflow_dag_args["is_paused_upon_creation"] = self.is_paused_upon_creation

# workflow timeout should only be enforced if a dag is scheduled.
if self.workflow_timeout is not None and self.schedule_interval is not None:
airflow_dag_args["dagrun_timeout"] = dict(seconds=self.workflow_timeout)

appending_sensors = self._collect_flow_sensors()
workflow = Workflow(
dag_id=self.name,
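For reference, a hedged sketch of the annotations each `KubernetesPodOperator` now carries. The keys match the ones built in `_to_job` above; the values, project, and branch names are purely illustrative, and the last three keys appear only when the flow also uses `@project`.

```python
# Illustrative values only; the keys come from the annotations block added in _to_job.
annotations = {
    "metaflow/owner": "valay",
    "metaflow/user": "valay",
    "metaflow/flow_name": "HelloFlow",
    # Added only when current.get("project_name") is set, i.e. the flow uses @project.
    "metaflow/project_name": "demo",
    "metaflow/branch_name": "user.valay",
    "metaflow/project_flow_name": "demo.user.valay.HelloFlow",
}
```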
27 changes: 8 additions & 19 deletions metaflow/plugins/airflow/airflow_cli.py
@@ -63,13 +63,12 @@ def airflow(obj, name=None):
show_default=True,
help="Maximum number of parallel processes.",
)
# TODO: Enable workflow timeout.
# @click.option(
# "--workflow-timeout",
# default=None,
# type=int,
# help="Workflow timeout in seconds. Enforced only for scheduled DAGs.",
# )
@click.option(
"--workflow-timeout",
default=None,
type=int,
help="Workflow timeout in seconds. Enforced only for scheduled DAGs.",
)
@click.option(
"--worker-pool",
default=None,
@@ -123,7 +122,7 @@ def make_flow(
file,
):
# Validate if the workflow is correctly parsed.
# _validate_workflow(obj.flow, obj.graph, obj.flow_datastore, obj.metadata)
_validate_workflow(obj.flow, obj.graph, obj.flow_datastore, obj.metadata)

# Attach @kubernetes.
decorators._attach_decorators(obj.flow, [KubernetesDecorator.name])
@@ -157,26 +156,16 @@ def make_flow(
username=get_username(),
max_workers=max_workers,
worker_pool=worker_pool,
# workflow_timeout=workflow_timeout,
workflow_timeout=workflow_timeout,
description=obj.flow.__doc__,
file_path=file,
is_paused_upon_creation=is_paused_upon_creation,
)


# TODO: Clean this out
def _validate_workflow(flow, graph, flow_datastore, metadata):
# check for other compute related decorators.
# supported compute : k8s (v1), local(v2), batch(v3),
# todo : check for the flow level decorators are correctly set.
# TODO: Move the check to the decorator
schedule_interval = flow._flow_decorators.get("airflow_schedule_interval")
schedule = flow._flow_decorators.get("schedule")
if schedule is not None and schedule_interval is not None:
raise AirflowException(
"Flow cannot have @schedule and @airflow_schedule_interval at the same time. Use any one."
)
# This check can be handled by airflow.py
for node in graph:
if node.type == "foreach":
raise NotSupportedException(
8 changes: 8 additions & 0 deletions metaflow/plugins/airflow/airflow_decorator.py
@@ -5,6 +5,7 @@

from metaflow.decorators import FlowDecorator, StepDecorator
from metaflow.metadata import MetaDatum
from .exception import AirflowException

from .airflow_utils import TASK_ID_XCOM_KEY, AirflowTask, SensorNames

@@ -54,6 +55,13 @@ class AirflowScheduleIntervalDecorator(FlowDecorator):
def flow_init(
self, flow, graph, environment, flow_datastore, metadata, logger, echo, options
):
schedule_interval = flow._flow_decorators.get("airflow_schedule_interval")
schedule = flow._flow_decorators.get("schedule")
if schedule is not None and schedule_interval is not None:
raise AirflowException(
"Flow cannot have @schedule and @airflow_schedule_interval at the same time. Use any one."
)

self._option_values = options

if self._option_values["schedule"]:
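A hedged illustration of what the relocated check means for users: attaching both scheduling decorators to one flow now raises an `AirflowException` from `flow_init` instead of being caught by the CLI-side validation. The import path and the argument name for `airflow_schedule_interval` below are assumptions, not something shown in this diff.

```python
from metaflow import FlowSpec, schedule, step
# Assumed import path; the decorator is registered under the name "airflow_schedule_interval".
from metaflow import airflow_schedule_interval


@schedule(hourly=True)
@airflow_schedule_interval(cron="0 * * * *")  # argument name is a placeholder
class DoublyScheduledFlow(FlowSpec):
    @step
    def start(self):
        self.next(self.end)

    @step
    def end(self):
        pass


# Compiling this flow for Airflow now fails with:
#   AirflowException: Flow cannot have @schedule and @airflow_schedule_interval
#   at the same time. Use any one.
```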
104 changes: 41 additions & 63 deletions metaflow/plugins/airflow/airflow_utils.py
@@ -63,65 +63,57 @@ def json_dump(val):
return json.dumps(val)


# TODO : (savin-comments) Fix serialization of args :
class AirflowDAGArgs(object):
# _arg_types This object helps map types of
# different keys that need to be parsed. None of the "values" in this
# dictionary are being used. But the "types" of the values of are used when
# reparsing the arguments from the config variable.

# TODO: These values are being overriden in airflow.py. Can we list the
# sensible defaults directly here so that it is easier to grok the code.
# `_arg_types` maps Airflow `DAG` argument names to their types.
# It is used when parsing values back from the configuration JSON.
# It doesn't cover every argument, but it covers the important ones that can come from the CLI.
_arg_types = {
"dag_id": "asdf",
"description": "asdfasf",
"schedule_interval": "*/2 * * * *",
"start_date": datetime.now(),
"catchup": False,
"tags": [],
"max_retry_delay": "",
"dagrun_timeout": timedelta(minutes=60 * 4),
"dag_id": str,
"description": str,
"schedule_interval": str,
"start_date": datetime,
"catchup": bool,
"tags": list,
"dagrun_timeout": timedelta,
"default_args": {
"owner": "some_username",
"depends_on_past": False,
"email": ["some_email"],
"email_on_failure": False,
"email_on_retry": False,
"retries": 1,
"retry_delay": timedelta(seconds=10),
"queue": "bash_queue", # which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues.
"pool": "backfill", # the slot pool this task should run in, slot pools are a way to limit concurrency for certain tasks
"priority_weight": 10,
"wait_for_downstream": False,
"sla": timedelta(hours=2),
"execution_timeout": timedelta(minutes=10),
"trigger_rule": "all_success",
"owner": str,
"depends_on_past": bool,
"email": list,
"email_on_failure": bool,
"email_on_retry": bool,
"retries": int,
"retry_delay": timedelta,
"queue": str, # which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues.
"pool": str, # the slot pool this task should run in, slot pools are a way to limit concurrency for certain tasks
"priority_weight": int,
"wait_for_downstream": bool,
"sla": timedelta,
"execution_timeout": timedelta,
"trigger_rule": str,
},
}

metaflow_specific_args = {
# Reference for user_defined_filters : https://stackoverflow.com/a/70175317
"user_defined_filters": dict(
task_id_creator=lambda v: task_id_creator(v),
json_dump=lambda val: json_dump(val),
run_id_creator=lambda val: run_id_creator(val),
),
}
# Reference for user_defined_filters : https://stackoverflow.com/a/70175317
filters = dict(
task_id_creator=lambda v: task_id_creator(v),
json_dump=lambda val: json_dump(val),
run_id_creator=lambda val: run_id_creator(val),
)

def __init__(self, **kwargs):
self._args = kwargs

@property
def arguments(self):
return dict(**self._args, **self.metaflow_specific_args)
return dict(**self._args, user_defined_filters=self.filters)

# just serialize?
def _serialize_args(self):
def serialize(self):
def parse_args(dd):
data_dict = {}
for k, v in dd.items():
# see the comment below for `from_dict`
if k == "default_args":
if isinstance(v, dict):
data_dict[k] = parse_args(v)
elif isinstance(v, datetime):
data_dict[k] = v.isoformat()
@@ -133,38 +125,28 @@ def parse_args(dd):

return parse_args(self._args)

# just deserialize?
@classmethod
def from_dict(cls, data_dict):
def deserialize(cls, data_dict):
def parse_args(dd, type_check_dict):
kwrgs = {}
for k, v in dd.items():
if k not in type_check_dict:
kwrgs[k] = v
continue
# wouldn't you want to do this parsing for any type of nested structure
# that is not datetime or timedelta? that should remove the reliance on
# the magic word - default_args
if k == "default_args":
if isinstance(v, dict) and isinstance(type_check_dict[k], dict):
kwrgs[k] = parse_args(v, type_check_dict[k])
elif isinstance(type_check_dict[k], datetime):
elif type_check_dict[k] == datetime:
kwrgs[k] = datetime.fromisoformat(v)
elif isinstance(type_check_dict[k], timedelta):
elif type_check_dict[k] == timedelta:
kwrgs[k] = timedelta(**v)
else:
kwrgs[k] = v
return kwrgs

return cls(**parse_args(data_dict, cls._arg_types))

def to_dict(self):
# dd is quite cryptic. why not just return self._serialize? also do we even need
# this method? how about we just use `serialize`?
dd = self._serialize_args()
return dd


def _kubernetes_pod_operator_args(flow_name, step_name, operator_args):
def _kubernetes_pod_operator_args(operator_args):
from kubernetes import client

from airflow.kubernetes.secret import Secret
@@ -179,8 +161,6 @@ def _kubernetes_pod_operator_args(flow_name, step_name, operator_args):
"secrets": secrets,
# Question for (savin):
# Default timeout in airflow is 120. I can remove `startup_timeout_seconds` for now. how should we expose it to the user?
# todo :annotations are not empty. see @kubernetes or argo-workflows
"annotations": {},
}
)
# Below cannot be passed in dictionary form. After trying a few times it didn't work.
@@ -310,9 +290,7 @@ def _kubenetes_task(self):
"Install the Airflow Kubernetes provider using : "
"`pip install apache-airflow-providers-cncf-kubernetes`"
)
k8s_args = _kubernetes_pod_operator_args(
self._flow_name, self.name, self._operator_args
)
k8s_args = _kubernetes_pod_operator_args(self._operator_args)
return KubernetesPodOperator(**k8s_args)

def to_task(self):
@@ -341,7 +319,7 @@ def to_dict(self):
return dict(
graph_structure=self.graph_structure,
states={s: v.to_dict() for s, v in self.states.items()},
dag_instantiation_params=self._dag_instantiation_params.to_dict(),
dag_instantiation_params=self._dag_instantiation_params.serialize(),
file_path=self._file_path,
metaflow_params=self.metaflow_params,
)
@@ -355,7 +333,7 @@ def from_dict(cls, data_dict):
file_path=data_dict["file_path"],
graph_structure=data_dict["graph_structure"],
)
re_cls._dag_instantiation_params = AirflowDAGArgs.from_dict(
re_cls._dag_instantiation_params = AirflowDAGArgs.deserialize(
data_dict["dag_instantiation_params"]
)

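A hedged round-trip sketch of the renamed serialization methods. The argument values are illustrative; timeouts are supplied as plain dicts, exactly as `airflow.py` now passes `dagrun_timeout` and `execution_timeout`, and `deserialize` rebuilds them as `timedelta` objects by consulting `_arg_types`.

```python
from datetime import datetime

from metaflow.plugins.airflow.airflow_utils import AirflowDAGArgs

# Illustrative arguments; dagrun_timeout and retry_delay are plain dicts so they survive JSON.
args = AirflowDAGArgs(
    dag_id="hello_flow_dag",
    schedule_interval="0 * * * *",
    start_date=datetime(2022, 7, 29),
    dagrun_timeout=dict(seconds=3600),
    default_args=dict(owner="valay", retries=1, retry_delay=dict(seconds=10)),
)

payload = args.serialize()                      # datetimes become ISO strings; nested dicts recurse
restored = AirflowDAGArgs.deserialize(payload)  # ISO strings -> datetime, timeout dicts -> timedelta
dag_kwargs = restored.arguments                 # also carries the user_defined_filters for the DAG file
```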
44 changes: 19 additions & 25 deletions metaflow/plugins/airflow/sensors/base_sensor.py
@@ -1,6 +1,7 @@
import uuid
from metaflow.decorators import FlowDecorator
from ..exception import AirflowException
from ..airflow_utils import AirflowTask
from ..airflow_utils import AirflowTask, task_id_creator


class AirflowSensorDecorator(FlowDecorator):
@@ -25,9 +26,9 @@ class AirflowSensorDecorator(FlowDecorator):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO : (savin-comments) : refactor the name of `self._task_name` to have a common name.
# Is the task is task name a metaflow task?
self._task_name = self.operator_type
self._airflow_task_name = None
self._id = str(uuid.uuid4())

def serialize_operator_args(self):
"""
@@ -45,7 +46,7 @@ def serialize_operator_args(self):
def create_task(self):
task_args = self.serialize_operator_args()
return AirflowTask(
self._task_name,
self._airflow_task_name,
operator_type=self.operator_type,
).set_operator_args(**{k: v for k, v in task_args.items() if v is not None})

@@ -54,27 +55,20 @@ def compile(self):
Compile the arguments for the `airflow create` command.
This will also check whether the arguments are acceptable.
"""
# If there are more than one sensor decorators then ensure that `name` is set
# so that we can have uniqueness of `task_id` when creating tasks on airflow.
sensor_decorators = [
d
for d in self._flow_decorators
if issubclass(d.__class__, AirflowSensorDecorator)
]
sensor_deco_types = {}
for d in sensor_decorators:
if d.__class__.__name__ not in sensor_deco_types:
sensor_deco_types[d.__class__.__name__] = []
sensor_deco_types[d.__class__.__name__].append(d)
# If there are more than one decorator per sensor-type then we require the name argument.
if sum([len(v) for v in sensor_deco_types.values()]) > len(sensor_deco_types):
if self.attributes["name"] is None:
# TODO : (savin-comments) autogenerate this name
raise AirflowException(
"`name` argument cannot be `None` when multiple Airflow Sensor related decorators are attached to a flow."
)
if self.attributes["name"] is not None:
self._task_name = self.attributes["name"]
# If no name is set, auto-generate one, since there can be more than one
# `AirflowSensorDecorator` of the same type attached to a flow.
if self.attributes["name"] is None:
deco_index = [
d._id
for d in self._flow_decorators
if issubclass(d.__class__, AirflowSensorDecorator)
].index(self._id)
self._airflow_task_name = "%s-%s" % (
self.operator_type,
task_id_creator([self.operator_type, str(deco_index)]),
)
else:
self._airflow_task_name = self.attributes["name"]

def flow_init(
self, flow, graph, environment, flow_datastore, metadata, logger, echo, options
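Finally, a hedged sketch of the sensor auto-naming rule: when no `name` attribute is given, `compile()` derives a deterministic task name from the operator type and the decorator's position among the flow's sensor decorators. The operator type and index below are placeholders; `task_id_creator` is the hashing helper now imported at the top of `base_sensor.py`.

```python
from metaflow.plugins.airflow.airflow_utils import task_id_creator

# Placeholders standing in for self.operator_type and the decorator's index (deco_index).
operator_type = "ExternalTaskSensor"
deco_index = 1

# Mirrors the naming added in compile(): "<operator_type>-<task_id_creator([operator_type, index])>".
airflow_task_name = "%s-%s" % (
    operator_type,
    task_id_creator([operator_type, str(deco_index)]),
)
print(airflow_task_name)  # e.g. "ExternalTaskSensor-<generated id>"
```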
