Skip to content

Commit

Permalink
Max workers and worker pool support.
Browse files Browse the repository at this point in the history
  • Loading branch information
valayDave committed Apr 7, 2022
1 parent 9c973f2 commit 5c97b15
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 14 deletions.
36 changes: 23 additions & 13 deletions metaflow/plugins/airflow/airflow_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def airflow(ctx):
pass


def make_flow(obj, tags, namespace, worker_pools, is_project, file_path=None):
def make_flow(
obj, tags, namespace, max_workers, is_project, file_path=None, worker_pool=None
):
# Attach K8s decorator over here.
# todo This will be affected in the future based on how many compute providers are supported on Airflow.
decorators._attach_decorators(obj.flow, [KubernetesDecorator.name])
Expand All @@ -45,7 +47,8 @@ def make_flow(obj, tags, namespace, worker_pools, is_project, file_path=None):
obj.monitor,
tags=tags,
namespace=namespace,
max_workers=worker_pools,
max_workers=max_workers,
worker_pool=worker_pool,
username=get_username(),
is_project=is_project,
description=obj.flow.__doc__,
Expand All @@ -70,35 +73,42 @@ def make_flow(obj, tags, namespace, worker_pools, is_project, file_path=None):
default=None,
)
@click.option(
"--only-json",
is_flag=True,
default=False,
help="Only print out JSON",
"--max-workers",
default=100,
show_default=True,
help="Maximum number of concurrent Airflow tasks to run for the DAG.",
)
@click.option(
"--worker-pools",
default=100,
"--worker-pool",
default=None,
show_default=True,
help="Worker pool for the Airflow tasks.",
)
@click.pass_obj
def create(
obj,
file_path,
tags=None,
user_namespace=None,
only_json=False,
worker_pools=None,
max_workers=None,
worker_pool=None,
):
flow = make_flow(
obj, tags, user_namespace, worker_pools, False, file_path=file_path
obj,
tags,
user_namespace,
max_workers,
False,
file_path=file_path,
worker_pool=worker_pool,
)
compiled_dag_file = flow.compile()
if file_path is None:
obj.echo_always(compiled_dag_file)
else:
if file_path.startswith('s3://'):
if file_path.startswith("s3://"):
with S3() as s3:
s3.put(file_path,compiled_dag_file)
s3.put(file_path, compiled_dag_file)
else:
with open(file_path, "w") as f:
f.write(compiled_dag_file)
14 changes: 13 additions & 1 deletion metaflow/plugins/airflow/airflow_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
username=None,
max_workers=None,
is_project=False,
worker_pool=None,
email=None,
start_date=datetime.now(),
description=None,
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
self._file_path = file_path
self.metaflow_parameters = None
_, self.graph_structure = self.graph.output_steps()
self.worker_pool = worker_pool

def _get_schedule(self):
schedule = self.flow._flow_decorators.get("schedule")
Expand Down Expand Up @@ -540,6 +542,11 @@ def _visit(node: DAGNode, workflow: Workflow, exit_node=None):
)
return workflow

# Set max active tasks here. For more info, see:
# https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/models/dag/index.html#airflow.models.dag.DAG
other_args = (
{} if self.max_workers is None else dict(max_active_tasks=self.max_workers)
)
workflow = Workflow(
dag_id=self.name,
default_args=self._create_defaults(),
Expand All @@ -550,6 +557,7 @@ def _visit(node: DAGNode, workflow: Workflow, exit_node=None):
tags=self.tags,
file_path=self._file_path,
graph_structure=self.graph_structure,
**other_args
)
workflow = _visit(self.graph["start"], workflow)
workflow.set_parameters(self.metaflow_parameters)
Expand All @@ -571,7 +579,7 @@ def _create_airflow_file(self, json_dag):
)

def _create_defaults(self):
return {
defu_ = {
"owner": get_username(),
# If set on a task, doesn’t run the task in the current DAG run if the previous run of the task has failed.
"depends_on_past": False,
Expand All @@ -583,3 +591,7 @@ def _create_defaults(self):
# check https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/models/baseoperator/index.html?highlight=retry_delay#airflow.models.baseoperator.BaseOperatorMeta
"retry_delay": timedelta(seconds=5),
}
if self.worker_pool is not None:
defu_["pool"] = self.worker_pool

return defu_

0 comments on commit 5c97b15

Please sign in to comment.