feat: use taskflow for concurrent task runs
ividito authored and vishal committed Oct 28, 2024
1 parent 0df795e commit bfce3d5
Showing 10 changed files with 343 additions and 213 deletions.
79 changes: 29 additions & 50 deletions dags/veda_data_pipeline/groups/discover_group.py
@@ -4,27 +4,29 @@
from airflow.models.variable import Variable
from airflow.models.xcom import LazyXComAccess
from airflow.operators.dummy_operator import DummyOperator as EmptyOperator
from airflow.decorators import task_group
from airflow.decorators import task_group, task
from airflow.models.baseoperator import chain
from airflow.operators.python import BranchPythonOperator, PythonOperator, ShortCircuitOperator
from airflow.utils.trigger_rule import TriggerRule
from airflow_multi_dagrun.operators import TriggerMultiDagRunOperator
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
from veda_data_pipeline.utils.s3_discovery import (
s3_discovery_handler, EmptyFileListError
)
from veda_data_pipeline.groups.processing_group import subdag_process
from veda_data_pipeline.groups.processing_tasks import build_stac_kwargs, submit_to_stac_ingestor_task


group_kwgs = {"group_id": "Discover", "tooltip": "Discover"}


def discover_from_s3_task(ti, event={}, **kwargs):
@task
def discover_from_s3_task(ti=None, event={}, **kwargs):
"""Discover grouped assets/files from S3 in batches of 2800. Produce a list of such files stored on S3 to process.
This task is used as part of the discover_group subdag and outputs data to EVENT_BUCKET.
"""
config = {
**event,
**ti.dag_run.conf,
}
# TODO test that this context var is available in taskflow
last_successful_execution = kwargs.get("prev_start_date_success")
if event.get("schedule") and last_successful_execution:
config["last_successful_execution"] = last_successful_execution.isoformat()
@@ -43,17 +45,15 @@ def discover_from_s3_task(ti, event={}, **kwargs):
)
except EmptyFileListError as ex:
print(f"Received an exception {ex}")
return []
# TODO replace short circuit
return {}

@task
def get_files_to_process(**kwargs):
def get_files_to_process(payload, ti=None):
"""Get files from S3 produced by the discovery task.
Used as part of both the parallel_run_process_rasters and parallel_run_process_vectors tasks.
"""
ti = kwargs.get("ti")
dynamic_group_id = ti.task_id.split(".")[0]
payload = ti.xcom_pull(task_ids=f"{dynamic_group_id}.discover_from_s3")
if isinstance(payload, LazyXComAccess):
if isinstance(payload, LazyXComAccess): # if used as part of a dynamic task mapping
payloads_xcom = payload[0].pop("payload", [])
payload = payload[0]
else:
@@ -66,52 +66,31 @@ def get_files_to_process(**kwargs):
} for indx, payload_xcom in enumerate(payloads_xcom)]


def vector_raster_choice(ti):
"""Choose whether to process rasters or vectors based on the payload."""
payload = ti.dag_run.conf
dynamic_group_id = ti.task_id.split(".")[0]

if payload.get("vector"):
return f"{dynamic_group_id}.parallel_run_process_generic_vectors"
if payload.get("vector_eis"):
return f"{dynamic_group_id}.parallel_run_process_vectors"
return f"{dynamic_group_id}.parallel_run_process_rasters"

# this task group is defined for reference, but cannot be used in expanded taskgroup maps
@task_group
def subdag_discover(event={}):
discover_from_s3 = ShortCircuitOperator(
task_id="discover_from_s3",
python_callable=discover_from_s3_task,
op_kwargs={"text": "Discover from S3", "event": event},
trigger_rule=TriggerRule.NONE_FAILED,
provide_context=True,
)
# Define operators for non-taskflow tasks
discover_from_s3 = discover_from_s3_task(event=event)

raster_vector_branching = BranchPythonOperator(
task_id="raster_vector_branching",
python_callable=vector_raster_choice,
submit_to_stac_ingestor = PythonOperator(
task_id="submit_to_stac_ingestor",
python_callable=submit_to_stac_ingestor_task,
)

# define DAG using taskflow notation
discover_from_s3 = discover_from_s3_task(event=event)
get_files = get_files_to_process()

chain(discover_from_s3, get_files)

build_stac_kwargs_task = build_stac_kwargs.expand(event=get_files)
build_stac = EcsRunTaskOperator.partial(
task_id="build_stac"
).expand_kwargs(build_stac_kwargs_task)

run_process_raster = subdag_process.expand(get_files_to_process())
submit_to_stac_ingestor.expand(build_stac)

# TODO don't let me merge this without spending more time with vector ingest
run_process_vector = TriggerMultiDagRunOperator(
task_id="parallel_run_process_vectors",
trigger_dag_id="veda_ingest_vector",
python_callable=get_files_to_process,
)

run_process_generic_vector = TriggerMultiDagRunOperator(
task_id="parallel_run_process_generic_vectors",
trigger_dag_id="veda_generic_ingest_vector",
python_callable=get_files_to_process,
)

# extra no-op, needed to run in dynamic mapping context
end_discover = EmptyOperator(task_id="end_discover", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,)

discover_from_s3 >> raster_vector_branching >> [run_process_raster, run_process_vector, run_process_generic_vector]
run_process_raster >> end_discover
run_process_vector >> end_discover
run_process_generic_vector >> end_discover
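
For context, a minimal sketch of how a DAG could wire up the taskflow-style group above. The DAG id, schedule, and event payload here are hypothetical and only illustrate the intended call pattern; they are not code from this commit.

import pendulum
from airflow.decorators import dag

from veda_data_pipeline.groups.discover_group import subdag_discover


@dag(
    dag_id="example_veda_discover",  # hypothetical DAG id, for illustration only
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
    catchup=False,
)
def example_veda_discover():
    # The group chains discover_from_s3 -> get_files_to_process, then maps
    # build_stac and submit_to_stac_ingestor over the discovered payloads.
    subdag_discover(event={"collection": "example-collection"})  # example payload


example_veda_discover()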

95 changes: 0 additions & 95 deletions dags/veda_data_pipeline/groups/processing_group.py

This file was deleted.

165 changes: 165 additions & 0 deletions dags/veda_data_pipeline/groups/processing_tasks.py
@@ -0,0 +1,165 @@
import json
import logging
from datetime import timedelta  # needed for execution_timeout in build_vector_kwargs

import smart_open
from airflow.models.variable import Variable
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
from airflow.decorators import task_group, task
from veda_data_pipeline.utils.submit_stac import submission_handler

group_kwgs = {"group_id": "Process", "tooltip": "Process"}


def log_task(text: str):
logging.info(text)

@task()
def submit_to_stac_ingestor_task(built_stac: str):
"""Submit STAC items to the STAC ingestor API."""
event = json.loads(built_stac)
success_file = event["payload"]["success_event_key"]
with smart_open.open(success_file, "r") as _file:
stac_items = json.loads(_file.read())

for item in stac_items:
submission_handler(
event=item,
endpoint="/ingestions",
cognito_app_secret=Variable.get("COGNITO_APP_SECRET"),
stac_ingestor_api_url=Variable.get("STAC_INGESTOR_API_URL"),
)
return event

@task
def build_stac_kwargs(event={}):
"""Build kwargs for the ECS operator."""
mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True)
if event:
intermediate = {
**event
} # this is dumb but it resolves the MappedArgument to a dict that can be JSON serialized
payload = json.dumps(intermediate)
else:
payload = "{{ task_instance.dag_run.conf }}"

return {
"overrides": {
"containerOverrides": [
{
"name": f"{mwaa_stack_conf.get('PREFIX')}-veda-stac-build",
"command": [
"/usr/local/bin/python",
"handler.py",
"--payload",
payload,
],
"environment": [
{
"name": "EXTERNAL_ROLE_ARN",
"value": Variable.get(
"ASSUME_ROLE_READ_ARN", default_var=""
),
},
{
"name": "BUCKET",
"value": "veda-data-pipelines-staging-lambda-ndjson-bucket",
},
{
"name": "EVENT_BUCKET",
"value": mwaa_stack_conf.get("EVENT_BUCKET"),
},
],
"memory": 2048,
"cpu": 1024,
},
],
},
"network_configuration": {
"awsvpcConfiguration": {
"securityGroups": mwaa_stack_conf.get("SECURITYGROUPS"),
"subnets": mwaa_stack_conf.get("SUBNETS"),
},
},
"awslogs_group": mwaa_stack_conf.get("LOG_GROUP_NAME"),
"awslogs_stream_prefix": f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-stac-build",
}

@task
def build_vector_kwargs(event={}):
"""Build kwargs for the ECS operator."""
mwaa_stack_conf = Variable.get(
"MWAA_STACK_CONF", default_var={}, deserialize_json=True
)
vector_ecs_conf = Variable.get(
"VECTOR_ECS_CONF", default_var={}, deserialize_json=True
)

if event:
intermediate = {
**event
}
payload = json.dumps(intermediate)
else:
payload = "{{ task_instance.dag_run.conf }}"

return {
"trigger_rule": "none_failed",
"cluster": f"{mwaa_stack_conf.get('PREFIX')}-cluster",
"task_definition": f"{mwaa_stack_conf.get('PREFIX')}-vector-tasks",
"launch_type": "FARGATE",
"do_xcom_push": True,
"execution_timeout": timedelta(minutes=120),
"overrides": {
"containerOverrides": [
{
"name": f"{mwaa_stack_conf.get('PREFIX')}-veda-vector_ingest",
"command": [
"/var/lang/bin/python",
"handler.py",
"--payload",
payload,
],
"environment": [
{
"name": "EXTERNAL_ROLE_ARN",
"value": Variable.get(
"ASSUME_ROLE_READ_ARN", default_var=None
),
},
{
"name": "AWS_REGION",
"value": mwaa_stack_conf.get("AWS_REGION"),
},
{
"name": "VECTOR_SECRET_NAME",
"value": Variable.get("VECTOR_SECRET_NAME"),
},
],
},
],
},
"network_configuration": {
"awsvpcConfiguration": {
"securityGroups": vector_ecs_conf.get("VECTOR_SECURITY_GROUP"),
"subnets": vector_ecs_conf.get("VECTOR_SUBNETS"),
},
},
"awslogs_group": mwaa_stack_conf.get("LOG_GROUP_NAME"),
"awslogs_stream_prefix": f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-vector_ingest",
}


@task_group
def subdag_process(event={}):

build_stac = EcsRunTaskOperator.partial(
task_id="build_stac"
).expand_kwargs(build_stac_kwargs(event=event))

submit_to_stac_ingestor = PythonOperator(
task_id="submit_to_stac_ingestor",
python_callable=submit_to_stac_ingestor_task,
)

build_stac >> submit_to_stac_ingestor
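
As a reference for the pattern used in subdag_process and subdag_discover, the sketch below shows Airflow's partial()/expand_kwargs() dynamic task mapping with a generic BashOperator. It is an illustration only; the task ids and commands are made up and are not part of this commit.

import pendulum
from airflow.decorators import dag, task
from airflow.operators.bash import BashOperator


@dag(start_date=pendulum.datetime(2024, 1, 1, tz="UTC"), schedule=None, catchup=False)
def expand_kwargs_example():
    @task
    def build_kwargs():
        # One kwargs dict per mapped task instance, analogous to build_stac_kwargs.
        return [{"bash_command": f"echo payload {i}"} for i in range(3)]

    @task
    def collect(result):
        # Mapped downstream task, analogous to submit_to_stac_ingestor_task.
        print(result)

    mapped = BashOperator.partial(task_id="run_payload").expand_kwargs(build_kwargs())
    collect.expand(result=mapped.output)


expand_kwargs_example()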