Advanced Params using json-schema
msumit committed Sep 12, 2021
1 parent d7aed84 commit 60759fc
Showing 26 changed files with 818 additions and 213 deletions.
27 changes: 15 additions & 12 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -257,18 +257,21 @@ def post_dag_run(dag_id, session):
.first()
)
if not dagrun_instance:
dag = current_app.dag_bag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
execution_date=logical_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
state=State.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
)
return dagrun_schema.dump(dag_run)
try:
dag = current_app.dag_bag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
execution_date=logical_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
state=State.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
)
return dagrun_schema.dump(dag_run)
except ValueError as ve:
raise BadRequest(detail=str(ve))

if dagrun_instance.execution_date == logical_date:
raise AlreadyExists(
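For API clients, the practical effect of the new try/except is that a dag run conf which fails the params json-schema now comes back as a 400 Bad Request instead of creating an invalid run. A minimal sketch of triggering such a run, assuming a local webserver with basic auth and the example_complex_params DAG (added later in this commit) loaded:

import requests

# int_param in example_complex_params is capped at 20, so this conf fails validation
resp = requests.post(
    "http://localhost:8080/api/v1/dags/example_complex_params/dagRuns",
    auth=("admin", "admin"),
    json={"conf": {"int_param": 42}},
)
print(resp.status_code)  # 400 -- the ValueError from create_dagrun surfaces as BadRequest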
8 changes: 7 additions & 1 deletion airflow/api_connexion/schemas/dag_schema.py
@@ -85,7 +85,7 @@ class DAGDetailSchema(DAGSchema):
dag_run_timeout = fields.Nested(TimeDeltaSchema, attribute="dagrun_timeout")
doc_md = fields.String()
default_view = fields.String()
params = fields.Dict()
params = fields.Method('get_params', dump_only=True)
tags = fields.Method("get_tags", dump_only=True)
is_paused = fields.Method("get_is_paused", dump_only=True)
is_active = fields.Method("get_is_active", dump_only=True)
@@ -115,6 +115,12 @@ def get_is_active(obj: DAG):
"""Checks entry in DAG table to see if this DAG is active"""
return obj.get_is_active()

@staticmethod
def get_params(obj: DAG):
"""Get the Params defined in a DAG"""
params = obj.params
return {k: v.dump() for k, v in params.items()}


class DAGCollection(NamedTuple):
"""List of DAGs with metadata"""
7 changes: 7 additions & 0 deletions airflow/api_connexion/schemas/task_schema.py
@@ -57,11 +57,18 @@ class TaskSchema(Schema):
template_fields = fields.List(fields.String(), dump_only=True)
sub_dag = fields.Nested(DAGSchema, dump_only=True)
downstream_task_ids = fields.List(fields.String(), dump_only=True)
params = fields.Method('get_params', dump_only=True)

def _get_class_reference(self, obj):
result = ClassReferenceSchema().dump(obj)
return result.data if hasattr(result, "data") else result

@staticmethod
def get_params(obj):
"""Get the Params defined in a Task"""
params = obj.params
return {k: v.dump() for k, v in params.items()}


class TaskCollection(NamedTuple):
"""List of Tasks with metadata"""
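The new airflow/models/param.py module itself is not part of this truncated listing, but the call sites above (v.dump(), ParamsDict(params), params.validate()) suggest roughly the following shape. This is an illustrative sketch built on the jsonschema library, not the module's actual code:

import jsonschema


class Param:
    """A DAG/task parameter with an optional default value and a json-schema."""

    def __init__(self, default=None, **schema):
        self.default = default
        self.schema = schema

    def resolve(self, value=None):
        # Validate the supplied value (or fall back to the default) against the schema.
        final = self.default if value is None else value
        jsonschema.validate(final, self.schema, format_checker=jsonschema.FormatChecker())
        return final

    def dump(self):
        # Serialized form consumed by the API schemas' get_params() methods above.
        return {'default': self.default, 'schema': self.schema}


class ParamsDict(dict):
    """Dict of Param objects; plain values are wrapped so old-style params keep working."""

    def __init__(self, params=None):
        super().__init__()
        for key, value in (params or {}).items():
            self[key] = value

    def __setitem__(self, key, value):
        if isinstance(value, Param):
            param = value
        elif key in self:
            # A plain value overriding an existing Param (e.g. from a dag_run conf)
            # keeps the original schema and only swaps the value.
            param = dict.__getitem__(self, key)
            param.default = value
        else:
            param = Param(value)
        dict.__setitem__(self, key, param)

    def update(self, other=None, **kwargs):
        for key, value in dict(other or {}, **kwargs).items():
            self[key] = value

    def validate(self):
        # Raise ValueError so the CLI and API call sites in this commit can handle it uniformly.
        for key, param in self.items():
            try:
                param.resolve()
            except jsonschema.ValidationError as err:
                raise ValueError(f"Invalid value for param `{key}`: {err.message}") from None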
34 changes: 19 additions & 15 deletions airflow/cli/commands/dag_command.py
@@ -108,21 +108,25 @@ def dag_backfill(args, dag=None):
dag_run_state=State.NONE,
)

dag.run(
start_date=args.start_date,
end_date=args.end_date,
mark_success=args.mark_success,
local=args.local,
donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
ignore_first_depends_on_past=args.ignore_first_depends_on_past,
ignore_task_deps=args.ignore_dependencies,
pool=args.pool,
delay_on_limit_secs=args.delay_on_limit,
verbose=args.verbose,
conf=run_conf,
rerun_failed_tasks=args.rerun_failed_tasks,
run_backwards=args.run_backwards,
)
try:
dag.run(
start_date=args.start_date,
end_date=args.end_date,
mark_success=args.mark_success,
local=args.local,
donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
ignore_first_depends_on_past=args.ignore_first_depends_on_past,
ignore_task_deps=args.ignore_dependencies,
pool=args.pool,
delay_on_limit_secs=args.delay_on_limit,
verbose=args.verbose,
conf=run_conf,
rerun_failed_tasks=args.rerun_failed_tasks,
run_backwards=args.run_backwards,
)
except ValueError as vr:
logging.error(str(vr))
sys.exit(1)


@cli_utils.action_logging
3 changes: 3 additions & 0 deletions airflow/cli/commands/task_command.py
@@ -437,6 +437,9 @@ def task_test(args, dag=None):
if args.task_params:
passed_in_params = json.loads(args.task_params)
task.params.update(passed_in_params)

task.params.validate()

ti = _get_ti(task, args.execution_date_or_run_id, create_if_necssary=True)

try:
46 changes: 46 additions & 0 deletions airflow/example_dags/example_complex_params.py
@@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example DAG demonstrating the usage of the complex params."""

from airflow import DAG
from airflow.models.param import Param
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago

with DAG(
"example_complex_params",
params={
'int_param': Param(10, type="integer", minimum=0, maximum=20), # an int param with a default value and bounds
'str_param': Param(type="string", minLength=2, maxLength=4), # a mandatory str param
'old_param': 'old_way_of_passing',
'simple_param': Param('im_just_like_old_param'), # i.e. no type checking
'email_param': Param(
'example@example.com', type='string', format='idn-email', minLength=5, maxLength=255
),
},
schedule_interval=None,
start_date=days_ago(1),
tags=['example'],
) as dag:
all_params = BashOperator(
task_id='all_param',
bash_command="echo {{ params.int_param }} {{ params.str_param }} {{ params.old_param }} "
"{{ params.simple_param }} {{ params.email_param }} {{ params.task_param }}",
params={'task_param': Param('im_a_task_param', type='string')},
)
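How these params behave at trigger time, given the create_dagrun() change further down in this commit: the DAG-level params are deep-copied, the user-supplied conf is merged on top, and the merged dict is validated. A rough illustration against the dag object defined above (not part of the commit):

import copy

# str_param has no default, so a conf that supplies a 2-4 character string passes
good = copy.deepcopy(dag.params)
good.update({"str_param": "abc"})
good.validate()

# int_param is capped at 20, so this merge fails with a ValueError
bad = copy.deepcopy(dag.params)
bad.update({"str_param": "abc", "int_param": 42})
bad.validate()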
1 change: 1 addition & 0 deletions airflow/models/__init__.py
@@ -25,6 +25,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.errors import ImportError
from airflow.models.log import Log
from airflow.models.param import Param
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.sensorinstance import SensorInstance # noqa: F401
8 changes: 5 additions & 3 deletions airflow/models/baseoperator.py
@@ -56,6 +56,7 @@
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models.base import Operator
from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
from airflow.models.taskinstance import Context, TaskInstance, clear_task_instances
from airflow.models.taskmixin import TaskMixin
@@ -148,7 +149,7 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
dag = kwargs.get('dag') or DagContext.get_current_dag()
if dag:
dag_args = copy.copy(dag.default_args) or {}
dag_params = copy.copy(dag.params) or {}
dag_params = copy.deepcopy(dag.params) or {}
task_group = TaskGroupContext.get_current_task_group(dag)
if task_group:
dag_args.update(task_group.default_args)
@@ -633,7 +634,7 @@ def __init__(
self.log.debug("max_retry_delay isn't a timedelta object, assuming secs")
self.max_retry_delay = timedelta(seconds=max_retry_delay)

self.params = params or {} # Available in templates!
self.params = ParamsDict(params)
if priority_weight is not None and not isinstance(priority_weight, int):
raise AirflowException(
f"`priority_weight` for task '{self.task_id}' only accepts integers, "
@@ -1092,7 +1093,7 @@ def render_template(
jinja_env = self.get_template_env()

# Imported here to avoid circular dependency
from airflow.models.dagparam import DagParam
from airflow.models.param import DagParam
from airflow.models.xcom_arg import XComArg

if isinstance(content, str):
@@ -1610,6 +1611,7 @@ def get_serialized_fields(cls):
'template_ext',
'template_fields',
'template_fields_renderers',
'params',
}
)
DagContext.pop_context_managed_dag()
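The switch from copy.copy to copy.deepcopy of dag.params above means each operator gets its own ParamsDict rather than sharing the DAG's Param objects, so a task-level override cannot leak back into the DAG-level params. A small illustration, assuming the Param/ParamsDict behaviour sketched earlier:

import copy

dag_params = ParamsDict({'threshold': Param(1, type='integer')})

task_params = copy.deepcopy(dag_params)
task_params.update({'threshold': 5})

assert task_params['threshold'].default == 5
assert dag_params['threshold'].default == 1  # unchanged thanks to the deep copy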
27 changes: 26 additions & 1 deletion airflow/models/dag.py
@@ -65,9 +65,9 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagbag import DagBag
from airflow.models.dagcode import DagCode
from airflow.models.dagparam import DagParam
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import DagRun
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances
from airflow.security import permissions
from airflow.stats import Stats
@@ -357,6 +357,9 @@ def __init__(
self.params.update(self.default_args['params'])
del self.default_args['params']

# check self.params and convert them into ParamsDict
self.params = ParamsDict(self.params)

if full_filepath:
warnings.warn(
"Passing full_filepath to DAG() is deprecated and has no effect",
@@ -473,6 +476,7 @@ def __init__(
self.render_template_as_native_obj = render_template_as_native_obj
self.tags = tags
self._task_group = TaskGroup.create_root(self)
self.validate_schedule_and_params()

def __repr__(self):
return f"<DAG: {self.dag_id}>"
@@ -2290,6 +2294,11 @@ def create_dagrun(
else:
data_interval = self.infer_automated_data_interval(logical_date)

# create a copy of params before validating
copied_params = copy.deepcopy(self.params)
copied_params.update(conf or {})
copied_params.validate()

run = DagRun(
dag_id=self.dag_id,
run_id=run_id,
@@ -2542,6 +2551,7 @@ def get_serialized_fields(cls):
'user_defined_filters',
'user_defined_macros',
'partial',
'params',
'_pickle_id',
'_log',
'is_subdag',
@@ -2576,6 +2586,21 @@ def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: Ed
"""
self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info

def validate_schedule_and_params(self):
"""
Validates & raises an exception if there are any Params in the DAG which neither have a default value nor
allow null in their schema['type'] list, while the DAG has a schedule_interval which is not None.
"""
if not self.timetable.can_run:
return

for k, v in self.params.items():
# As type can be an array, we check whether `null` is an allowed type or not
if v.default is None and ('type' not in v.schema or "null" not in v.schema['type']):
raise AirflowException(
'DAG Schedule must be None, if there are any required params without default values'
)


class DagTag(Base):
"""A tag name per dag, to allow quick filtering in the DAG view."""
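The net effect of validate_schedule_and_params() is that a scheduled DAG can no longer declare a required param that the scheduler would have no way to fill. A hedged example of a definition that would now raise the AirflowException above at parse time:

from airflow import DAG
from airflow.models.param import Param
from airflow.utils.dates import days_ago

with DAG(
    "scheduled_with_required_param",
    # no default value and null is not an allowed type, so scheduled runs could never satisfy it
    params={"needed": Param(type="string")},
    schedule_interval="@daily",
    start_date=days_ago(1),
) as dag:
    ...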
47 changes: 8 additions & 39 deletions airflow/models/dagparam.py
@@ -15,45 +15,14 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Optional
"""This module is deprecated. Please use :mod:`airflow.models.param`."""

from airflow.exceptions import AirflowException
import warnings

from airflow.models.param import DagParam # noqa

class DagParam:
"""
Class that represents a DAG run parameter.
It can be used to parameterize your dags. You can overwrite its value by setting it on conf
when you trigger your DagRun.
This can also be used in templates by accessing {{context.params}} dictionary.
**Example**:
with DAG(...) as dag:
EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))
:param current_dag: Dag being used for parameter.
:type current_dag: airflow.models.DAG
:param name: key value which is used to set the parameter
:type name: str
:param default: Default value used if no parameter was set.
:type default: Any
"""

def __init__(self, current_dag, name: str, default: Optional[Any] = None):
if default:
current_dag.params[name] = default
self._name = name
self._default = default

def resolve(self, context: Dict) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
default = self._default
if not self._default:
default = context['params'].get(self._name, None)
resolved = context['dag_run'].conf.get(self._name, default)
if not resolved:
raise AirflowException(f'No value could be resolved for parameter {self._name}')
return resolved
warnings.warn(
"This module is deprecated. Please use `airflow.models.param`.",
DeprecationWarning,
stacklevel=2,
)
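Existing imports from the old module path keep working but now emit a DeprecationWarning, for example:

from airflow.models.dagparam import DagParam  # warns: use airflow.models.param instead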
(The remaining changed files in this commit are not shown here.)
