Validate expected types for args for DAG, BaseOperator and TaskGroup #40269

Merged · 4 commits · Jun 21, 2024
Changes from 2 commits
42 changes: 37 additions & 5 deletions airflow/models/baseoperator.py
@@ -95,7 +95,7 @@
from airflow.utils.context import Context, context_get_outlet_events
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
from airflow.utils.helpers import validate_instance_args, validate_key
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
@@ -160,8 +160,6 @@ def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple
dag_args = copy.copy(dag.default_args)
dag_params = copy.deepcopy(dag.params)
if task_group:
if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
raise TypeError("default_args must be a mapping")
dag_args.update(task_group.default_args)
return dag_args, dag_params

@@ -178,8 +176,6 @@ def get_merged_defaults(
raise TypeError("params must be a mapping")
params.update(task_params)
if task_default_args:
if not isinstance(task_default_args, collections.abc.Mapping):
raise TypeError("default_args must be a mapping")
args.update(task_default_args)
with contextlib.suppress(KeyError):
params.update(task_default_args["params"] or {})
@@ -794,6 +790,40 @@ def say_hello_world(**context):
"executor",
}

_expected_args_types = {
"task_id": str,
"email": (str, Iterable),
"email_on_retry": bool,
"email_on_failure": bool,
"retries": int,
"retry_exponential_backoff": bool,
"depends_on_past": bool,
"ignore_first_depends_on_past": bool,
"wait_for_past_depends_before_skipping": bool,
"wait_for_downstream": bool,
"priority_weight": int,
"queue": str,
"pool": str,
"pool_slots": int,
"trigger_rule": str,
"run_as_user": str,
"task_concurrency": int,
"map_index_template": str,
"max_active_tis_per_dag": int,
"max_active_tis_per_dagrun": int,
"executor": str,
"do_xcom_push": bool,
"multiple_outputs": bool,
"doc": str,
"doc_md": str,
"doc_json": str,
"doc_yaml": str,
"doc_rst": str,
"task_display_name": str,
"logger_name": str,
"allow_nested_operators": bool,
}

# Defines if the operator supports lineage without manual definitions
supports_lineage = False

@@ -1078,6 +1108,8 @@ def __init__(
if SetupTeardownContext.active:
SetupTeardownContext.update_context_map(self)

validate_instance_args(self, self._expected_args_types)

def __eq__(self, other):
if type(self) is type(other):
# Use getattr() instead of __dict__ as __dict__ doesn't return
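
A minimal sketch (not part of the diff) of what this change means for DAG authors, assuming an Airflow build that includes this PR: a mistyped operator argument now fails with a TypeError when the operator is constructed, whether it is passed directly or via default_args, mirroring the new tests further down.

from airflow.models.baseoperator import BaseOperator

try:
    # max_active_tis_per_dag is declared as int in _expected_args_types
    BaseOperator(task_id="example", max_active_tis_per_dag="not_an_int")
except TypeError as err:
    print(err)
    # 'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int,
    # expected type is <class 'int'>

try:
    # the same check applies when the value arrives through default_args
    BaseOperator(task_id="example2", default_args={"max_active_tis_per_dag": "not_an_int"})
except TypeError as err:
    print(err)
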
23 changes: 22 additions & 1 deletion airflow/models/dag.py
@@ -128,7 +128,7 @@
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.dates import cron_presets, date_range as utils_date_range
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.helpers import at_most_one, exactly_one, validate_key
from airflow.utils.helpers import at_most_one, exactly_one, validate_instance_args, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import (
@@ -468,6 +468,25 @@ class DAG(LoggingMixin):
"last_loaded",
}

_expected_args_types = {
"dag_id": str,
"description": str,
"max_active_tasks": int,
"max_active_runs": int,
"max_consecutive_failed_dag_runs": int,
"dagrun_timeout": timedelta,
"default_view": str,
"orientation": str,
"catchup": bool,
"doc_md": str,
"is_paused_upon_creation": bool,
"render_template_as_native_obj": bool,
"tags": list,
"auto_register": bool,
"fail_stop": bool,
"dag_display_name": str,
}

__serialized_fields: frozenset[str] | None = None

fileloc: str
@@ -744,6 +763,8 @@ def __init__(
# fileloc based only on the serialize dag
self._processor_dags_folder = None

validate_instance_args(self, self._expected_args_types)

def get_doc_md(self, doc_md: str | None) -> str | None:
if doc_md is None:
return doc_md
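
The same parse-time check now applies to DAG keyword arguments. A short sketch (same assumptions as above), mirroring the new test_invalid_type_for_args below:

from airflow.models.dag import DAG

try:
    # max_consecutive_failed_dag_runs is declared as int in DAG._expected_args_types
    DAG("invalid-args-example", max_consecutive_failed_dag_runs="not_an_int")
except TypeError as err:
    print(err)
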
11 changes: 11 additions & 0 deletions airflow/utils/helpers.py
@@ -60,6 +60,17 @@ def validate_key(k: str, max_length: int = 250):
)


def validate_instance_args(instance: object, expected_arg_types: dict[str, Any]) -> None:
"""Validate that the instance has the expected types for the arguments."""
for arg_name, expected_arg_type in expected_arg_types.items():
instance_arg_value = getattr(instance, arg_name, None)
if instance_arg_value is not None and not isinstance(instance_arg_value, expected_arg_type):
raise TypeError(
f"'{arg_name}' has an invalid type {type(instance_arg_value)} with value "
f"{instance_arg_value}, expected type is {expected_arg_type}"
)


def validate_group_key(k: str, max_length: int = 200):
"""Validate value used as a group key."""
if not isinstance(k, str):
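
Note that validate_instance_args only inspects attributes that are present and not None, so unset values never trip the check. A quick standalone illustration (the Config class here is hypothetical, not from the diff):

from airflow.utils.helpers import validate_instance_args

class Config:
    def __init__(self, name, retries):
        self.name = name
        self.retries = retries

# passes: retries is None, so it is skipped
validate_instance_args(Config("etl", None), {"name": str, "retries": int})

try:
    # raises: 'retries' has an invalid type <class 'str'>, expected type is <class 'int'>
    validate_instance_args(Config("etl", "three"), {"name": str, "retries": int})
except TypeError as err:
    print(err)
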
13 changes: 12 additions & 1 deletion airflow/utils/task_group.py
@@ -36,7 +36,7 @@
)
from airflow.models.taskmixin import DAGNode
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
from airflow.utils.helpers import validate_group_key, validate_instance_args

if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -82,6 +82,15 @@ class TaskGroup(DAGNode):

used_group_ids: set[str | None]

_expected_args_types = {
"group_id": str,
"prefix_group_id": bool,
"tooltip": str,
"ui_color": str,
"ui_fgcolor": str,
"add_suffix_on_collision": bool,
}

def __init__(
self,
group_id: str | None,
@@ -160,6 +169,8 @@ def __init__(
self.upstream_task_ids = set()
self.downstream_task_ids = set()

validate_instance_args(self, self._expected_args_types)

def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
if self._group_id is None:
return
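
TaskGroup constructor arguments get the same treatment. A brief sketch (same assumptions), mirroring the new test_task_group_with_invalid_arg_type_raises_error:

from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

with DAG(dag_id="tg_validation_example"):
    try:
        # ui_color is declared as str, e.g. "red" or "#ff0000"
        TaskGroup("group_1", ui_color=123)
    except TypeError as err:
        print(err)
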
28 changes: 28 additions & 0 deletions tests/models/test_baseoperator.py
@@ -811,6 +811,34 @@ def test_logging_propogated_by_default(self, caplog):
# leaking a lot of state)
assert caplog.messages == ["test"]

def test_invalid_type_for_default_arg(self):
error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>"
with pytest.raises(TypeError, match=error_msg):
BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"})

def test_invalid_type_for_operator_arg(self):
error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>"
with pytest.raises(TypeError, match=error_msg):
BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int")

def test_baseoperator_defines_expected_arg_types(self):
operator = BaseOperator(task_id="test")

assert operator._expected_args_types is not None
assert isinstance(operator._expected_args_types, dict)

expected_args_types_subset = {
"task_id": str,
"email_on_retry": bool,
"email_on_failure": bool,
"retries": int,
"retry_exponential_backoff": bool,
"depends_on_past": bool,
}
assert set(operator._expected_args_types.items()).intersection(
set(expected_args_types_subset.items())
) == set(expected_args_types_subset.items())


def test_init_subclass_args():
class InitSubclassOp(BaseOperator):
24 changes: 24 additions & 0 deletions tests/models/test_dag.py
@@ -3928,6 +3928,30 @@ def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR
)


def test_invalid_type_for_args():
with pytest.raises(TypeError):
DAG("invalid-default-args", max_consecutive_failed_dag_runs="not_an_int")


def test_dag_defines_expected_args():
dag = DAG("dag_with_expected_args")

assert dag._expected_args_types is not None
assert isinstance(dag._expected_args_types, dict)

expected_args_types_subset = {
"dag_id": str,
"description": str,
"max_active_tasks": int,
"max_active_runs": int,
"max_consecutive_failed_dag_runs": int,
}

assert set(dag._expected_args_types.items()).intersection(set(expected_args_types_subset.items())) == set(
expected_args_types_subset.items()
)


class TestTaskClearingSetupTeardownBehavior:
"""
Task clearing behavior is mainly controlled by dag.partial_subset.
34 changes: 34 additions & 0 deletions tests/utils/test_helpers.py
@@ -33,6 +33,7 @@
merge_dicts,
prune_dict,
validate_group_key,
validate_instance_args,
validate_key,
)
from airflow.utils.types import NOTSET
@@ -355,3 +356,36 @@ class SchedulerJobRunner(MockJobRunner):

class TriggererJobRunner(MockJobRunner):
job_type = "TriggererJob"


class ClassToValidateArgs:
def __init__(self, name, age, active):
self.name = name
self.age = age
self.active = active


# Edge cases
@pytest.mark.parametrize(
"instance, expected_arg_types",
[
(ClassToValidateArgs("Alice", 30, None), {"name": str, "age": int, "active": bool}),
(ClassToValidateArgs(None, 25, True), {"name": str, "age": int, "active": bool}),
],
)
def test_validate_instance_args_raises_no_error(instance, expected_arg_types):
validate_instance_args(instance, expected_arg_types)


# Error cases
@pytest.mark.parametrize(
"instance, expected_arg_types",
[
(ClassToValidateArgs("Alice", "thirty", True), {"name": str, "age": int, "active": bool}),
(ClassToValidateArgs("Bob", 25, "yes"), {"name": str, "age": int, "active": bool}),
(ClassToValidateArgs(123, 25, True), {"name": str, "age": int, "active": bool}),
],
)
def test_validate_instance_args_raises_error(instance, expected_arg_types):
with pytest.raises(TypeError):
validate_instance_args(instance, expected_arg_types)
28 changes: 28 additions & 0 deletions tests/utils/test_task_group.py
@@ -1630,3 +1630,31 @@ def work(): ...
assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
assert set(t1.operator.downstream_task_ids) == set()
assert set(t2.operator.downstream_task_ids) == set()


def test_task_group_with_invalid_arg_type_raises_error():
error_msg = "'ui_color' has an invalid type <class 'int'> with value 123, expected type is <class 'str'>"
with DAG(dag_id="dag_with_tg_invalid_arg_type"):
with pytest.raises(TypeError, match=error_msg):
with TaskGroup("group_1", ui_color=123):
EmptyOperator(task_id="task1")


def test_task_group_defines_expected_arg_types():
with DAG(dag_id="dag_with_tg_valid_arg_types"):
with TaskGroup("group_1", ui_color="red") as tg:
EmptyOperator(task_id="task1")

assert tg._expected_args_types is not None
assert isinstance(tg._expected_args_types, dict)

expected_args_types_subset = {
"group_id": str,
"prefix_group_id": bool,
"tooltip": str,
"ui_color": str,
}

assert set(tg._expected_args_types.items()).intersection(set(expected_args_types_subset.items())) == set(
expected_args_types_subset.items()
)