From 5b2765de125cae3b1564eda7f83157fc31051cc0 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Sun, 16 Jun 2024 20:37:23 +0530 Subject: [PATCH 1/4] Validate expected types for args for DAG, BaseOperator and TaskGroup --- airflow/models/baseoperator.py | 42 ++++++++++++++++++++++++++++++---- airflow/models/dag.py | 23 ++++++++++++++++++- airflow/utils/helpers.py | 11 +++++++++ airflow/utils/task_group.py | 13 ++++++++++- tests/utils/test_helpers.py | 34 +++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 7 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index bbd629cfc156c..76a09809f4ab6 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -95,7 +95,7 @@ from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.decorators import fixup_decorator_warning_stack from airflow.utils.edgemodifier import EdgeModifier -from airflow.utils.helpers import validate_key +from airflow.utils.helpers import validate_instance_args, validate_key from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session @@ -160,8 +160,6 @@ def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple dag_args = copy.copy(dag.default_args) dag_params = copy.deepcopy(dag.params) if task_group: - if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping): - raise TypeError("default_args must be a mapping") dag_args.update(task_group.default_args) return dag_args, dag_params @@ -178,8 +176,6 @@ def get_merged_defaults( raise TypeError("params must be a mapping") params.update(task_params) if task_default_args: - if not isinstance(task_default_args, collections.abc.Mapping): - raise TypeError("default_args must be a mapping") args.update(task_default_args) with contextlib.suppress(KeyError): 
params.update(task_default_args["params"] or {}) @@ -794,6 +790,40 @@ def say_hello_world(**context): "executor", } + _expected_args_types = { + "task_id": str, + "email": (str, Iterable), + "email_on_retry": bool, + "email_on_failure": bool, + "retries": int, + "retry_exponential_backoff": bool, + "depends_on_past": bool, + "ignore_first_depends_on_past": bool, + "wait_for_past_depends_before_skipping": bool, + "wait_for_downstream": bool, + "priority_weight": int, + "queue": str, + "pool": str, + "pool_slots": int, + "trigger_rule": str, + "run_as_user": str, + "task_concurrency": int, + "map_index_template": str, + "max_active_tis_per_dag": int, + "max_active_tis_per_dagrun": int, + "executor": str, + "do_xcom_push": bool, + "multiple_outputs": bool, + "doc": str, + "doc_md": str, + "doc_json": str, + "doc_yaml": str, + "doc_rst": str, + "task_display_name": str, + "logger_name": str, + "allow_nested_operators": bool, + } + # Defines if the operator supports lineage without manual definitions supports_lineage = False @@ -1078,6 +1108,8 @@ def __init__( if SetupTeardownContext.active: SetupTeardownContext.update_context_map(self) + validate_instance_args(self, self._expected_args_types) + def __eq__(self, other): if type(self) is type(other): # Use getattr() instead of __dict__ as __dict__ doesn't return diff --git a/airflow/models/dag.py b/airflow/models/dag.py index b2de299e64646..668137169382d 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -128,7 +128,7 @@ from airflow.utils.dag_cycle_tester import check_cycle from airflow.utils.dates import cron_presets, date_range as utils_date_range from airflow.utils.decorators import fixup_decorator_warning_stack -from airflow.utils.helpers import at_most_one, exactly_one, validate_key +from airflow.utils.helpers import at_most_one, exactly_one, validate_instance_args, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session 
from airflow.utils.sqlalchemy import ( @@ -468,6 +468,25 @@ class DAG(LoggingMixin): "last_loaded", } + _expected_args_types = { + "dag_id": str, + "description": str, + "max_active_tasks": int, + "max_active_runs": int, + "max_consecutive_failed_dag_runs": int, + "dagrun_timeout": timedelta, + "default_view": str, + "orientation": str, + "catchup": bool, + "doc_md": str, + "is_paused_upon_creation": bool, + "render_template_as_native_obj": bool, + "tags": list, + "auto_register": bool, + "fail_stop": bool, + "dag_display_name": str, + } + __serialized_fields: frozenset[str] | None = None fileloc: str @@ -744,6 +763,8 @@ def __init__( # fileloc based only on the serialize dag self._processor_dags_folder = None + validate_instance_args(self, self._expected_args_types) + def get_doc_md(self, doc_md: str | None) -> str | None: if doc_md is None: return doc_md diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 3ad87bd104035..8d14aadc21fa9 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -60,6 +60,17 @@ def validate_key(k: str, max_length: int = 250): ) +def validate_instance_args(instance: object, expected_arg_types: dict[str, Any]) -> None: + """Validate that the instance has the expected types for the arguments.""" + for arg_name, expected_arg_type in expected_arg_types.items(): + instance_arg_value = getattr(instance, arg_name, None) + if instance_arg_value is not None and not isinstance(instance_arg_value, expected_arg_type): + raise TypeError( + f"'{arg_name}' has an invalid type {type(instance_arg_value)} with value " + f"{instance_arg_value}, expected type is {expected_arg_type}" + ) + + def validate_group_key(k: str, max_length: int = 200): """Validate value used as a group key.""" if not isinstance(k, str): diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 37b03e1907f15..c7c3279990fff 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -36,7 +36,7 @@ ) from 
airflow.models.taskmixin import DAGNode from airflow.serialization.enums import DagAttributeTypes -from airflow.utils.helpers import validate_group_key +from airflow.utils.helpers import validate_group_key, validate_instance_args if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -82,6 +82,15 @@ class TaskGroup(DAGNode): used_group_ids: set[str | None] + _expected_args_types = { + "group_id": str, + "prefix_group_id": bool, + "tooltip": str, + "ui_color": str, + "ui_fgcolor": str, + "add_suffix_on_collision": bool, + } + def __init__( self, group_id: str | None, @@ -160,6 +169,8 @@ def __init__( self.upstream_task_ids = set() self.downstream_task_ids = set() + validate_instance_args(self, self._expected_args_types) + def _check_for_group_id_collisions(self, add_suffix_on_collision: bool): if self._group_id is None: return diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 75b04b14c7b50..27ef5a76db5b7 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -33,6 +33,7 @@ merge_dicts, prune_dict, validate_group_key, + validate_instance_args, validate_key, ) from airflow.utils.types import NOTSET @@ -355,3 +356,36 @@ class SchedulerJobRunner(MockJobRunner): class TriggererJobRunner(MockJobRunner): job_type = "TriggererJob" + + +class ClassToValidateArgs: + def __init__(self, name, age, active): + self.name = name + self.age = age + self.active = active + + +# Edge cases +@pytest.mark.parametrize( + "instance, expected_arg_types", + [ + (ClassToValidateArgs("Alice", 30, None), {"name": str, "age": int, "active": bool}), + (ClassToValidateArgs(None, 25, True), {"name": str, "age": int, "active": bool}), + ], +) +def test_validate_instance_args_raises_no_error(instance, expected_arg_types): + validate_instance_args(instance, expected_arg_types) + + +# Error cases +@pytest.mark.parametrize( + "instance, expected_arg_types", + [ + (ClassToValidateArgs("Alice", "thirty", True), {"name": str, "age": int, "active": bool}), 
+ (ClassToValidateArgs("Bob", 25, "yes"), {"name": str, "age": int, "active": bool}), + (ClassToValidateArgs(123, 25, True), {"name": str, "age": int, "active": bool}), + ], +) +def test_validate_instance_args_raises_error(instance, expected_arg_types): + with pytest.raises(TypeError): + validate_instance_args(instance, expected_arg_types) From 4882e249dff3c63764dd194773a96d5a7b125e0d Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Mon, 17 Jun 2024 00:27:01 +0530 Subject: [PATCH 2/4] Add tests --- tests/models/test_baseoperator.py | 28 ++++++++++++++++++++++++++++ tests/models/test_dag.py | 24 ++++++++++++++++++++++++ tests/utils/test_task_group.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 06310cd395ca0..36ab4a08aa826 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -811,6 +811,34 @@ def test_logging_propogated_by_default(self, caplog): # leaking a lot of state) assert caplog.messages == ["test"] + def test_invalid_type_for_default_arg(self): + error_msg = "'max_active_tis_per_dag' has an invalid type with value not_an_int, expected type is " + with pytest.raises(TypeError, match=error_msg): + BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"}) + + def test_invalid_type_for_operator_arg(self): + error_msg = "'max_active_tis_per_dag' has an invalid type with value not_an_int, expected type is " + with pytest.raises(TypeError, match=error_msg): + BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int") + + def test_baseoperator_defines_expected_arg_types(self): + operator = BaseOperator(task_id="test") + + assert operator._expected_args_types is not None + assert isinstance(operator._expected_args_types, dict) + + expected_args_types_subset = { + "task_id": str, + "email_on_retry": bool, + "email_on_failure": bool, + "retries": int, + "retry_exponential_backoff": bool, + 
"depends_on_past": bool, + } + assert set(operator._expected_args_types.items()).intersection( + set(expected_args_types_subset.items()) + ) == set(expected_args_types_subset.items()) + def test_init_subclass_args(): class InitSubclassOp(BaseOperator): diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 6d931ffbdc53c..5459a70c46a3c 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -3928,6 +3928,30 @@ def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR ) +def test_invalid_type_for_args(): + with pytest.raises(TypeError): + DAG("invalid-default-args", max_consecutive_failed_dag_runs="not_an_int") + + +def test_dag_defines_expected_args(): + dag = DAG("dag_with_expected_args") + + assert dag._expected_args_types is not None + assert isinstance(dag._expected_args_types, dict) + + expected_args_types_subset = { + "dag_id": str, + "description": str, + "max_active_tasks": int, + "max_active_runs": int, + "max_consecutive_failed_dag_runs": int, + } + + assert set(dag._expected_args_types.items()).intersection(set(expected_args_types_subset.items())) == set( + expected_args_types_subset.items() + ) + + class TestTaskClearingSetupTeardownBehavior: """ Task clearing behavior is mainly controlled by dag.partial_subset. diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 359980ec0956f..418083390c194 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -1630,3 +1630,31 @@ def work(): ... 
assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"} assert set(t1.operator.downstream_task_ids) == set() assert set(t2.operator.downstream_task_ids) == set() + + +def test_task_group_with_invalid_arg_type_raises_error(): + error_msg = "'ui_color' has an invalid type with value 123, expected type is " + with DAG(dag_id="dag_with_tg_invalid_arg_type"): + with pytest.raises(TypeError, match=error_msg): + with TaskGroup("group_1", ui_color=123): + EmptyOperator(task_id="task1") + + +def test_task_group_defines_expected_arg_types(): + with DAG(dag_id="dag_with_tg_valid_arg_types"): + with TaskGroup("group_1", ui_color="red") as tg: + EmptyOperator(task_id="task1") + + assert tg._expected_args_types is not None + assert isinstance(tg._expected_args_types, dict) + + expected_args_types_subset = { + "group_id": str, + "prefix_group_id": bool, + "tooltip": str, + "ui_color": str, + } + + assert set(tg._expected_args_types.items()).intersection(set(expected_args_types_subset.items())) == set( + expected_args_types_subset.items() + ) From 0813a7a7870e90968c68bde2aaac4515dd2106d4 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Mon, 17 Jun 2024 00:55:17 +0530 Subject: [PATCH 3/4] Apply suggestions from code review --- airflow/models/baseoperator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 76a09809f4ab6..776cec076f1e7 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -160,6 +160,8 @@ def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple dag_args = copy.copy(dag.default_args) dag_params = copy.deepcopy(dag.params) if task_group: + if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping): + raise TypeError("default_args must be a mapping") dag_args.update(task_group.default_args) return dag_args, dag_params @@ -176,6 +178,8 @@ def get_merged_defaults( raise 
TypeError("params must be a mapping") params.update(task_params) if task_default_args: + if not isinstance(task_default_args, collections.abc.Mapping): + raise TypeError("default_args must be a mapping") args.update(task_default_args) with contextlib.suppress(KeyError): params.update(task_default_args["params"] or {}) From 271e6a8c706a4619d90045db3f63b8955f274559 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Mon, 17 Jun 2024 16:42:01 +0530 Subject: [PATCH 4/4] Address @uranusjr's review comments --- airflow/models/baseoperator.py | 77 +++++++++++++++++-------------- airflow/models/dag.py | 47 +++++++++++-------- airflow/utils/task_group.py | 26 +++++++---- tests/models/test_baseoperator.py | 19 ++------ tests/models/test_dag.py | 19 ++------ tests/utils/test_task_group.py | 20 ++------ 6 files changed, 98 insertions(+), 110 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 776cec076f1e7..5d58cb80dc32c 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -521,6 +521,47 @@ def __new__(cls, name, bases, namespace, **kwargs): return new_cls +# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the +# correct type. This is a temporary solution until we find a more sophisticated method for argument +# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not +# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python +# version that supports `get_type_hints` effectively or find a better approach, we can replace this +# manual type-checking method.
+BASEOPERATOR_ARGS_EXPECTED_TYPES = { + "task_id": str, + "email": (str, Iterable), + "email_on_retry": bool, + "email_on_failure": bool, + "retries": int, + "retry_exponential_backoff": bool, + "depends_on_past": bool, + "ignore_first_depends_on_past": bool, + "wait_for_past_depends_before_skipping": bool, + "wait_for_downstream": bool, + "priority_weight": int, + "queue": str, + "pool": str, + "pool_slots": int, + "trigger_rule": str, + "run_as_user": str, + "task_concurrency": int, + "map_index_template": str, + "max_active_tis_per_dag": int, + "max_active_tis_per_dagrun": int, + "executor": str, + "do_xcom_push": bool, + "multiple_outputs": bool, + "doc": str, + "doc_md": str, + "doc_json": str, + "doc_yaml": str, + "doc_rst": str, + "task_display_name": str, + "logger_name": str, + "allow_nested_operators": bool, +} + + @total_ordering class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): r""" @@ -794,40 +835,6 @@ def say_hello_world(**context): "executor", } - _expected_args_types = { - "task_id": str, - "email": (str, Iterable), - "email_on_retry": bool, - "email_on_failure": bool, - "retries": int, - "retry_exponential_backoff": bool, - "depends_on_past": bool, - "ignore_first_depends_on_past": bool, - "wait_for_past_depends_before_skipping": bool, - "wait_for_downstream": bool, - "priority_weight": int, - "queue": str, - "pool": str, - "pool_slots": int, - "trigger_rule": str, - "run_as_user": str, - "task_concurrency": int, - "map_index_template": str, - "max_active_tis_per_dag": int, - "max_active_tis_per_dagrun": int, - "executor": str, - "do_xcom_push": bool, - "multiple_outputs": bool, - "doc": str, - "doc_md": str, - "doc_json": str, - "doc_yaml": str, - "doc_rst": str, - "task_display_name": str, - "logger_name": str, - "allow_nested_operators": bool, - } - # Defines if the operator supports lineage without manual definitions supports_lineage = False @@ -1112,7 +1119,7 @@ def __init__( if SetupTeardownContext.active: 
SetupTeardownContext.update_context_map(self) - validate_instance_args(self, self._expected_args_types) + validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) def __eq__(self, other): if type(self) is type(other): diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 668137169382d..97e39d4a93f08 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -341,6 +341,32 @@ def _create_orm_dagrun( return run +# TODO: The following mapping is used to validate that the arguments passed to the DAG are of the correct +# type. This is a temporary solution until we find a more sophisticated method for argument validation. +# One potential method is to use `get_type_hints` from the typing module. However, this is not fully +# compatible with future annotations for Python versions below 3.10. Once we require a minimum Python +# version that supports `get_type_hints` effectively or find a better approach, we can replace this +# manual type-checking method. +DAG_ARGS_EXPECTED_TYPES = { + "dag_id": str, + "description": str, + "max_active_tasks": int, + "max_active_runs": int, + "max_consecutive_failed_dag_runs": int, + "dagrun_timeout": timedelta, + "default_view": str, + "orientation": str, + "catchup": bool, + "doc_md": str, + "is_paused_upon_creation": bool, + "render_template_as_native_obj": bool, + "tags": list, + "auto_register": bool, + "fail_stop": bool, + "dag_display_name": str, +} + + @functools.total_ordering class DAG(LoggingMixin): """ @@ -468,25 +494,6 @@ class DAG(LoggingMixin): "last_loaded", } - _expected_args_types = { - "dag_id": str, - "description": str, - "max_active_tasks": int, - "max_active_runs": int, - "max_consecutive_failed_dag_runs": int, - "dagrun_timeout": timedelta, - "default_view": str, - "orientation": str, - "catchup": bool, - "doc_md": str, - "is_paused_upon_creation": bool, - "render_template_as_native_obj": bool, - "tags": list, - "auto_register": bool, - "fail_stop": bool, - "dag_display_name": str, - } - 
__serialized_fields: frozenset[str] | None = None fileloc: str @@ -763,7 +770,7 @@ def __init__( # fileloc based only on the serialize dag self._processor_dags_folder = None - validate_instance_args(self, self._expected_args_types) + validate_instance_args(self, DAG_ARGS_EXPECTED_TYPES) def get_doc_md(self, doc_md: str | None) -> str | None: if doc_md is None: diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index c7c3279990fff..4c158498d1570 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -49,6 +49,21 @@ from airflow.models.taskmixin import DependencyMixin from airflow.utils.edgemodifier import EdgeModifier +# TODO: The following mapping is used to validate that the arguments passed to the TaskGroup are of the +# correct type. This is a temporary solution until we find a more sophisticated method for argument +# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not +# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python +# version that supports `get_type_hints` effectively or find a better approach, we can replace this +# manual type-checking method.
+TASKGROUP_ARGS_EXPECTED_TYPES = { + "group_id": str, + "prefix_group_id": bool, + "tooltip": str, + "ui_color": str, + "ui_fgcolor": str, + "add_suffix_on_collision": bool, +} + class TaskGroup(DAGNode): """ @@ -82,15 +97,6 @@ class TaskGroup(DAGNode): used_group_ids: set[str | None] - _expected_args_types = { - "group_id": str, - "prefix_group_id": bool, - "tooltip": str, - "ui_color": str, - "ui_fgcolor": str, - "add_suffix_on_collision": bool, - } - def __init__( self, group_id: str | None, @@ -169,7 +175,7 @@ def __init__( self.upstream_task_ids = set() self.downstream_task_ids = set() - validate_instance_args(self, self._expected_args_types) + validate_instance_args(self, TASKGROUP_ARGS_EXPECTED_TYPES) def _check_for_group_id_collisions(self, add_suffix_on_collision: bool): if self._group_id is None: diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 36ab4a08aa826..a73db360aabb2 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -32,6 +32,7 @@ from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule, RemovedInAirflow3Warning from airflow.lineage.entities import File from airflow.models.baseoperator import ( + BASEOPERATOR_ARGS_EXPECTED_TYPES, BaseOperator, BaseOperatorMeta, chain, @@ -821,23 +822,11 @@ def test_invalid_type_for_operator_arg(self): with pytest.raises(TypeError, match=error_msg): BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int") - def test_baseoperator_defines_expected_arg_types(self): + @mock.patch("airflow.models.baseoperator.validate_instance_args") + def test_baseoperator_init_validates_arg_types(self, mock_validate_instance_args): operator = BaseOperator(task_id="test") - assert operator._expected_args_types is not None - assert isinstance(operator._expected_args_types, dict) - - expected_args_types_subset = { - "task_id": str, - "email_on_retry": bool, - "email_on_failure": bool, - "retries": int, - 
"retry_exponential_backoff": bool, - "depends_on_past": bool, - } - assert set(operator._expected_args_types.items()).intersection( - set(expected_args_types_subset.items()) - ) == set(expected_args_types_subset.items()) + mock_validate_instance_args.assert_called_once_with(operator, BASEOPERATOR_ARGS_EXPECTED_TYPES) def test_init_subclass_args(): diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 5459a70c46a3c..0a1247d4ed997 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -56,6 +56,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dag import ( DAG, + DAG_ARGS_EXPECTED_TYPES, DagModel, DagOwnerAttributes, DagTag, @@ -3933,23 +3934,11 @@ def test_invalid_type_for_args(): DAG("invalid-default-args", max_consecutive_failed_dag_runs="not_an_int") -def test_dag_defines_expected_args(): +@mock.patch("airflow.models.dag.validate_instance_args") +def test_dag_init_validates_arg_types(mock_validate_instance_args): dag = DAG("dag_with_expected_args") - assert dag._expected_args_types is not None - assert isinstance(dag._expected_args_types, dict) - - expected_args_types_subset = { - "dag_id": str, - "description": str, - "max_active_tasks": int, - "max_active_runs": int, - "max_consecutive_failed_dag_runs": int, - } - - assert set(dag._expected_args_types.items()).intersection(set(expected_args_types_subset.items())) == set( - expected_args_types_subset.items() - ) + mock_validate_instance_args.assert_called_once_with(dag, DAG_ARGS_EXPECTED_TYPES) class TestTaskClearingSetupTeardownBehavior: diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 418083390c194..df593d5b787c1 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -18,6 +18,7 @@ from __future__ import annotations from datetime import timedelta +from unittest import mock import pendulum import pytest @@ -37,7 +38,7 @@ from airflow.operators.empty import EmptyOperator from 
airflow.operators.python import PythonOperator from airflow.utils.dag_edges import dag_edges -from airflow.utils.task_group import TaskGroup, task_group_to_dict +from airflow.utils.task_group import TASKGROUP_ARGS_EXPECTED_TYPES, TaskGroup, task_group_to_dict from tests.models import DEFAULT_DATE @@ -1640,21 +1641,10 @@ def test_task_group_with_invalid_arg_type_raises_error(): EmptyOperator(task_id="task1") -def test_task_group_defines_expected_arg_types(): +@mock.patch("airflow.utils.task_group.validate_instance_args") +def test_task_group_init_validates_arg_types(mock_validate_instance_args): with DAG(dag_id="dag_with_tg_valid_arg_types"): with TaskGroup("group_1", ui_color="red") as tg: EmptyOperator(task_id="task1") - assert tg._expected_args_types is not None - assert isinstance(tg._expected_args_types, dict) - - expected_args_types_subset = { - "group_id": str, - "prefix_group_id": bool, - "tooltip": str, - "ui_color": str, - } - - assert set(tg._expected_args_types.items()).intersection(set(expected_args_types_subset.items())) == set( - expected_args_types_subset.items() - ) + mock_validate_instance_args.assert_called_with(tg, TASKGROUP_ARGS_EXPECTED_TYPES)