Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced Params using json-schema #17100

Merged
merged 7 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,21 @@ def post_dag_run(dag_id, session):
.first()
)
if not dagrun_instance:
dag = current_app.dag_bag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
execution_date=logical_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
state=State.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
)
return dagrun_schema.dump(dag_run)
try:
dag = current_app.dag_bag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
execution_date=logical_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
state=State.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
)
return dagrun_schema.dump(dag_run)
except ValueError as ve:
raise BadRequest(detail=str(ve))

if dagrun_instance.execution_date == logical_date:
raise AlreadyExists(
Expand Down
8 changes: 7 additions & 1 deletion airflow/api_connexion/schemas/dag_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class DAGDetailSchema(DAGSchema):
dag_run_timeout = fields.Nested(TimeDeltaSchema, attribute="dagrun_timeout")
doc_md = fields.String()
default_view = fields.String()
params = fields.Dict()
params = fields.Method('get_params', dump_only=True)
tags = fields.Method("get_tags", dump_only=True)
is_paused = fields.Method("get_is_paused", dump_only=True)
is_active = fields.Method("get_is_active", dump_only=True)
Expand Down Expand Up @@ -115,6 +115,12 @@ def get_is_active(obj: DAG):
"""Checks entry in DAG table to see if this DAG is active"""
return obj.get_is_active()

@staticmethod
def get_params(obj: DAG):
"""Get the Params defined in a DAG"""
params = obj.params
return {k: v.dump() for k, v in params.items()}


class DAGCollection(NamedTuple):
"""List of DAGs with metadata"""
Expand Down
7 changes: 7 additions & 0 deletions airflow/api_connexion/schemas/task_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,18 @@ class TaskSchema(Schema):
template_fields = fields.List(fields.String(), dump_only=True)
sub_dag = fields.Nested(DAGSchema, dump_only=True)
downstream_task_ids = fields.List(fields.String(), dump_only=True)
params = fields.Method('get_params', dump_only=True)

def _get_class_reference(self, obj):
result = ClassReferenceSchema().dump(obj)
return result.data if hasattr(result, "data") else result

@staticmethod
def get_params(obj):
"""Get the Params defined in a Task"""
params = obj.params
return {k: v.dump() for k, v in params.items()}


class TaskCollection(NamedTuple):
"""List of Tasks with metadata"""
Expand Down
34 changes: 19 additions & 15 deletions airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,25 @@ def dag_backfill(args, dag=None):
dag_run_state=State.NONE,
)

dag.run(
start_date=args.start_date,
end_date=args.end_date,
mark_success=args.mark_success,
local=args.local,
donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
ignore_first_depends_on_past=args.ignore_first_depends_on_past,
ignore_task_deps=args.ignore_dependencies,
pool=args.pool,
delay_on_limit_secs=args.delay_on_limit,
verbose=args.verbose,
conf=run_conf,
rerun_failed_tasks=args.rerun_failed_tasks,
run_backwards=args.run_backwards,
)
try:
dag.run(
start_date=args.start_date,
end_date=args.end_date,
mark_success=args.mark_success,
local=args.local,
donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
ignore_first_depends_on_past=args.ignore_first_depends_on_past,
ignore_task_deps=args.ignore_dependencies,
pool=args.pool,
delay_on_limit_secs=args.delay_on_limit,
verbose=args.verbose,
conf=run_conf,
rerun_failed_tasks=args.rerun_failed_tasks,
run_backwards=args.run_backwards,
)
except ValueError as vr:
print(str(vr))
sys.exit(1)


@cli_utils.action_logging
Expand Down
4 changes: 4 additions & 0 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ def task_test(args, dag=None):
if args.task_params:
passed_in_params = json.loads(args.task_params)
task.params.update(passed_in_params)

if task.params:
task.params.validate()

ti = _get_ti(task, args.execution_date_or_run_id, create_if_necssary=True)

try:
Expand Down
1 change: 1 addition & 0 deletions airflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.errors import ImportError
from airflow.models.log import Log
from airflow.models.param import Param
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.sensorinstance import SensorInstance # noqa: F401
Expand Down
8 changes: 5 additions & 3 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models.base import Operator
from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
from airflow.models.taskinstance import Context, TaskInstance, clear_task_instances
from airflow.models.taskmixin import TaskMixin
Expand Down Expand Up @@ -148,7 +149,7 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
dag = kwargs.get('dag') or DagContext.get_current_dag()
if dag:
dag_args = copy.copy(dag.default_args) or {}
dag_params = copy.copy(dag.params) or {}
dag_params = copy.deepcopy(dag.params) or {}
task_group = TaskGroupContext.get_current_task_group(dag)
if task_group:
dag_args.update(task_group.default_args)
Expand Down Expand Up @@ -633,7 +634,7 @@ def __init__(
self.log.debug("max_retry_delay isn't a timedelta object, assuming secs")
self.max_retry_delay = timedelta(seconds=max_retry_delay)

self.params = params or {} # Available in templates!
self.params = ParamsDict(params)
if priority_weight is not None and not isinstance(priority_weight, int):
raise AirflowException(
f"`priority_weight` for task '{self.task_id}' only accepts integers, "
Expand Down Expand Up @@ -1092,7 +1093,7 @@ def render_template(
jinja_env = self.get_template_env()

# Imported here to avoid circular dependency
from airflow.models.dagparam import DagParam
from airflow.models.param import DagParam
from airflow.models.xcom_arg import XComArg

if isinstance(content, str):
Expand Down Expand Up @@ -1610,6 +1611,7 @@ def get_serialized_fields(cls):
'template_ext',
'template_fields',
'template_fields_renderers',
'params',
}
)
DagContext.pop_context_managed_dag()
Expand Down
27 changes: 26 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagbag import DagBag
from airflow.models.dagcode import DagCode
from airflow.models.dagparam import DagParam
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import DagRun
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances
from airflow.security import permissions
from airflow.stats import Stats
Expand Down Expand Up @@ -357,6 +357,9 @@ def __init__(
self.params.update(self.default_args['params'])
del self.default_args['params']

# check self.params and convert them into ParamsDict
self.params = ParamsDict(self.params)

if full_filepath:
warnings.warn(
"Passing full_filepath to DAG() is deprecated and has no effect",
Expand Down Expand Up @@ -473,6 +476,7 @@ def __init__(
self.render_template_as_native_obj = render_template_as_native_obj
self.tags = tags
self._task_group = TaskGroup.create_root(self)
self.validate_schedule_and_params()

def __repr__(self):
return f"<DAG: {self.dag_id}>"
Expand Down Expand Up @@ -2290,6 +2294,11 @@ def create_dagrun(
else:
data_interval = self.infer_automated_data_interval(logical_date)

# create a copy of params before validating
copied_params = copy.deepcopy(self.params)
copied_params.update(conf or {})
copied_params.validate()

run = DagRun(
dag_id=self.dag_id,
run_id=run_id,
Expand Down Expand Up @@ -2542,6 +2551,7 @@ def get_serialized_fields(cls):
'user_defined_filters',
'user_defined_macros',
'partial',
'params',
'_pickle_id',
'_log',
'is_subdag',
Expand Down Expand Up @@ -2576,6 +2586,21 @@ def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: Ed
"""
self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info

def validate_schedule_and_params(self):
"""
Validates & raise exception if there are any Params in the DAG which neither have a default value nor
have the null in schema['type'] list, but the DAG have a schedule_interval which is not None.
"""
if not self.timetable.can_run:
return

for k, v in self.params.items():
# As type can be an array, we would check if `null` is a allowed type or not
if v.default is None and ("type" not in v.schema or "null" not in v.schema["type"]):
raise AirflowException(
"DAG Schedule must be None, if there are any required params without default values"
)


class DagTag(Base):
"""A tag name per dag, to allow quick filtering in the DAG view."""
Expand Down
47 changes: 8 additions & 39 deletions airflow/models/dagparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,14 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Optional
"""This module is deprecated. Please use :mod:`airflow.models.param`."""

from airflow.exceptions import AirflowException
import warnings

from airflow.models.param import DagParam # noqa

class DagParam:
"""
Class that represents a DAG run parameter.

It can be used to parameterize your dags. You can overwrite its value by setting it on conf
when you trigger your DagRun.

This can also be used in templates by accessing {{context.params}} dictionary.

**Example**:

with DAG(...) as dag:
EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

:param current_dag: Dag being used for parameter.
:type current_dag: airflow.models.DAG
:param name: key value which is used to set the parameter
:type name: str
:param default: Default value used if no parameter was set.
:type default: Any
"""

def __init__(self, current_dag, name: str, default: Optional[Any] = None):
if default:
current_dag.params[name] = default
self._name = name
self._default = default

def resolve(self, context: Dict) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
default = self._default
if not self._default:
default = context['params'].get(self._name, None)
resolved = context['dag_run'].conf.get(self._name, default)
if not resolved:
raise AirflowException(f'No value could be resolved for parameter {self._name}')
return resolved
warnings.warn(
"This module is deprecated. Please use `airflow.models.param`.",
DeprecationWarning,
stacklevel=2,
)
Loading