diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py index ba1e111e4031c..9f60f971f941e 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py @@ -58,7 +58,7 @@ class BaseHITLDetail(BaseModel): body: str | None = None defaults: list[str] | None = None multiple: bool = False - params: dict[str, Any] = Field(default_factory=dict) + params: Mapping = Field(default_factory=dict) assigned_users: list[HITLUser] = Field(default_factory=list) created_at: datetime @@ -74,7 +74,20 @@ class BaseHITLDetail(BaseModel): @classmethod def get_params(cls, params: dict[str, Any]) -> dict[str, Any]: """Convert params attribute to dict representation.""" - return {k: v.dump() if getattr(v, "dump", None) else v for k, v in params.items()} + return { + key: value + if BaseHITLDetail._is_param(value) + else { + "value": value, + "description": None, + "schema": {}, + } + for key, value in params.items() + } + + @staticmethod + def _is_param(value: Any) -> bool: + return isinstance(value, dict) and all(key in value for key in ("description", "schema", "value")) class HITLDetail(BaseHITLDetail): diff --git a/airflow-core/src/airflow/ui/src/utils/hitl.ts b/airflow-core/src/airflow/ui/src/utils/hitl.ts index 7d38a09374d00..bd3a405175461 100644 --- a/airflow-core/src/airflow/ui/src/utils/hitl.ts +++ b/airflow-core/src/airflow/ui/src/utils/hitl.ts @@ -19,7 +19,7 @@ import type { TFunction } from "i18next"; import type { HITLDetail } from "openapi/requests/types.gen"; -import type { ParamsSpec } from "src/queries/useDagParams"; +import type { ParamSchema, ParamsSpec } from "src/queries/useDagParams"; export type HITLResponseParams = { chosen_options?: Array; @@ -70,7 +70,7 @@ export const getHITLParamsDict = ( searchParams: URLSearchParams, ): ParamsSpec => { const paramsDict: ParamsSpec = {}; - const { preloadedHITLOptions, preloadedHITLParams } = getPreloadHITLFormData(searchParams, hitlDetail); + const { preloadedHITLOptions } = getPreloadHITLFormData(searchParams, hitlDetail); const isApprovalTask = hitlDetail.options.includes("Approve") && hitlDetail.options.includes("Reject") && @@ -108,27 +108,36 @@ export const getHITLParamsDict = ( const sourceParams = hitlDetail.response_received ? hitlDetail.params_input : hitlDetail.params; Object.entries(sourceParams ?? {}).forEach(([key, value]) => { - const valueType = typeof value === "number" ? "number" : "string"; + if (!hitlDetail.params) { + return; + } + const paramData = hitlDetail.params[key] as ParamsSpec | undefined; + + const description: string = + paramData && typeof paramData.description === "string" ? paramData.description : ""; + + const schema: ParamSchema = { + const: undefined, + description_md: "", + enum: undefined, + examples: undefined, + format: undefined, + items: undefined, + maximum: undefined, + maxLength: undefined, + minimum: undefined, + minLength: undefined, + section: undefined, + title: key, + type: typeof value === "number" ? "number" : "string", + values_display: undefined, + ...(paramData?.schema && typeof paramData.schema === "object" ? paramData.schema : {}), + }; paramsDict[key] = { - description: "", - schema: { - const: undefined, - description_md: "", - enum: undefined, - examples: undefined, - format: undefined, - items: undefined, - maximum: undefined, - maxLength: undefined, - minimum: undefined, - minLength: undefined, - section: undefined, - title: key, - type: valueType, - values_display: undefined, - }, - value: preloadedHITLParams[key] ?? value, + description, + schema, + value: paramData?.value ?? value, }; }); } diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py index e09ebe76905e7..b2c05e982b8b9 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py @@ -216,7 +216,7 @@ def expected_sample_hitl_detail_dict(sample_ti: TaskInstance) -> dict[str, Any]: "defaults": ["Approve"], "multiple": False, "options": ["Approve", "Reject"], - "params": {"input_1": 1}, + "params": {"input_1": {"value": 1, "schema": {}, "description": None}}, "assigned_users": [], "created_at": mock.ANY, "params_input": {}, @@ -621,7 +621,7 @@ def test_should_respond_200_with_existing_response_and_concrete_query( "body": "this is body 0", "defaults": ["Approve"], "multiple": False, - "params": {"input_1": 1}, + "params": {"input_1": {"value": 1, "schema": {}, "description": None}}, "assigned_users": [], "created_at": DEFAULT_CREATED_AT.isoformat().replace("+00:00", "Z"), "responded_by_user": None, diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index d4ccb21b04b58..36605a7f4abdc 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -2275,7 +2275,7 @@ def test_should_respond_200_with_hitl( "defaults": ["Approve"], "multiple": False, "options": ["Approve", "Reject"], - "params": {"input_1": 1}, + "params": {"input_1": {"value": 1, "description": None, "schema": {}}}, "params_input": {}, "responded_at": None, "responded_by_user": None, @@ -3554,7 +3554,7 @@ def test_should_respond_200_with_hitl( "defaults": ["Approve"], "multiple": False, "options": ["Approve", "Reject"], - "params": {"input_1": 1}, + "params": {"input_1": {"value": 1, "description": None, "schema": {}}}, "params_input": {}, "responded_at": None, "responded_by_user": None, diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py index bf696fe4b8098..ad093637b790a 100644 --- a/devel-common/src/tests_common/test_utils/version_compat.py +++ b/devel-common/src/tests_common/test_utils/version_compat.py @@ -36,6 +36,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) AIRFLOW_V_3_0_3_PLUS = get_base_airflow_version_tuple() >= (3, 0, 3) AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) +AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3) AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) diff --git a/providers/standard/src/airflow/providers/standard/operators/hitl.py b/providers/standard/src/airflow/providers/standard/operators/hitl.py index 5a6b366957932..8b5c0cdd5b68b 100644 --- a/providers/standard/src/airflow/providers/standard/operators/hitl.py +++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py @@ -19,7 +19,7 @@ import logging from airflow.exceptions import AirflowOptionalProviderFeatureException -from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_3_PLUS, AIRFLOW_V_3_1_PLUS if not AIRFLOW_V_3_1_PLUS: raise AirflowOptionalProviderFeatureException("Human in the loop functionality needs Airflow 3.1+.") @@ -84,6 +84,7 @@ def __init__( self.multiple = multiple self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {}) + self.notifiers: Sequence[BaseNotifier] = ( [notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or [] ) @@ -110,6 +111,7 @@ def validate_params(self) -> None: Raises: ValueError: If `"_options"` key is present in `params`, which is not allowed. """ + self.params.validate() if "_options" in self.params: raise ValueError('"_options" is not allowed in params') @@ -165,8 +167,10 @@ def execute(self, context: Context): ) @property - def serialized_params(self) -> dict[str, Any]: - return self.params.dump() if isinstance(self.params, ParamsDict) else self.params + def serialized_params(self) -> dict[str, dict[str, Any]]: + if not AIRFLOW_V_3_1_3_PLUS: + return self.params.dump() if isinstance(self.params, ParamsDict) else self.params + return {k: self.params.get_param(k).serialize() for k in self.params} def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: if "error" in event: @@ -196,13 +200,12 @@ def validate_chosen_options(self, chosen_options: list[str]) -> None: def validate_params_input(self, params_input: Mapping) -> None: """Check whether user provide valid params input.""" - if ( - self.serialized_params is not None - and params_input is not None - and set(self.serialized_params.keys()) ^ set(params_input) - ): + if self.params and params_input and set(self.serialized_params.keys()) ^ set(params_input): raise ValueError(f"params_input {params_input} does not match params {self.params}") + for key, value in params_input.items(): + self.params[key] = value + def generate_link_to_ui( self, *, diff --git a/providers/standard/src/airflow/providers/standard/triggers/hitl.py b/providers/standard/src/airflow/providers/standard/triggers/hitl.py index 56384e17af48a..b36a3413dcd8e 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/hitl.py +++ b/providers/standard/src/airflow/providers/standard/triggers/hitl.py @@ -30,6 +30,9 @@ from asgiref.sync import sync_to_async +from airflow.exceptions import ParamValidationError +from airflow.sdk import Param +from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.execution_time.hitl import ( HITLUser, get_hitl_detail_content_detail, @@ -43,7 +46,7 @@ class HITLTriggerEventSuccessPayload(TypedDict, total=False): """Minimum required keys for a success Human-in-the-loop TriggerEvent.""" chosen_options: list[str] - params_input: dict[str, Any] + params_input: dict[str, dict[str, Any]] responded_by_user: HITLUser | None responded_at: datetime timedout: bool @@ -53,7 +56,7 @@ class HITLTriggerEventFailurePayload(TypedDict): """Minimum required keys for a failed Human-in-the-loop TriggerEvent.""" error: str - error_type: Literal["timeout", "unknown"] + error_type: Literal["timeout", "unknown", "validation"] class HITLTrigger(BaseTrigger): @@ -64,7 +67,7 @@ def __init__( *, ti_id: UUID, options: list[str], - params: dict[str, Any], + params: dict[str, dict[str, Any]], defaults: list[str] | None = None, multiple: bool = False, timeout_datetime: datetime | None, @@ -80,7 +83,21 @@ def __init__( self.defaults = defaults self.timeout_datetime = timeout_datetime - self.params = params + self.params = ParamsDict( + { + k: Param( + v.pop("value"), + **v, + ) + if HITLTrigger._is_param(v) + else Param(v) + for k, v in params.items() + }, + ) + + @staticmethod + def _is_param(value: Any) -> bool: + return isinstance(value, dict) and all(key in value for key in ("description", "schema", "value")) def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize HITLTrigger arguments and classpath.""" @@ -90,103 +107,131 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "ti_id": self.ti_id, "options": self.options, "defaults": self.defaults, - "params": self.params, + "params": {k: self.params.get_param(k).serialize() for k in self.params}, "multiple": self.multiple, "timeout_datetime": self.timeout_datetime, "poke_interval": self.poke_interval, }, ) - async def run(self) -> AsyncIterator[TriggerEvent]: - """Loop until the Human-in-the-loop response received or timeout reached.""" - while True: - if self.timeout_datetime and self.timeout_datetime < utcnow(): - # Fetch latest HITL detail before fallback - resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id) - # Response already received, yield success and exit - if resp.response_received and resp.chosen_options: - if TYPE_CHECKING: - assert resp.responded_by_user is not None - assert resp.responded_at is not None - - chosen_options_list = ( - list(resp.chosen_options) if resp.chosen_options is not None else None - ) - self.log.info( - "[HITL] responded_by=%s (id=%s) options=%s at %s (timeout fallback skipped)", - resp.responded_by_user.name, - resp.responded_by_user.id, - chosen_options_list, - resp.responded_at, - ) - yield TriggerEvent( - HITLTriggerEventSuccessPayload( - chosen_options=chosen_options_list, - params_input=resp.params_input or {}, - responded_at=resp.responded_at, - responded_by_user=HITLUser( - id=resp.responded_by_user.id, - name=resp.responded_by_user.name, - ), - timedout=False, - ) - ) - return - - if self.defaults is None: - yield TriggerEvent( - HITLTriggerEventFailurePayload( - error="The timeout has passed, and the response has not yet been received.", - error_type="timeout", - ) - ) - return - - resp = await sync_to_async(update_hitl_detail_response)( - ti_id=self.ti_id, - chosen_options=self.defaults, - params_input=self.params, + async def _handle_timeout(self) -> TriggerEvent: + """Handle HITL timeout logic and yield appropriate event.""" + resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id) + + # Case 1: Response arrived just before timeout + if resp.response_received and resp.chosen_options: + if TYPE_CHECKING: + assert resp.responded_by_user is not None + assert resp.responded_at is not None + + chosen_options_list = list(resp.chosen_options or []) + self.log.info( + "[HITL] responded_by=%s (id=%s) options=%s at %s (timeout fallback skipped)", + resp.responded_by_user.name, + resp.responded_by_user.id, + chosen_options_list, + resp.responded_at, + ) + return TriggerEvent( + HITLTriggerEventSuccessPayload( + chosen_options=chosen_options_list, + params_input=resp.params_input or {}, + responded_at=resp.responded_at, + responded_by_user=HITLUser( + id=resp.responded_by_user.id, + name=resp.responded_by_user.name, + ), + timedout=False, ) - if TYPE_CHECKING: - assert resp.responded_at is not None - self.log.info( - "[HITL] timeout reached before receiving response, fallback to default %s", self.defaults + ) + + # Case 2: No defaults defined → failure + if self.defaults is None: + return TriggerEvent( + HITLTriggerEventFailurePayload( + error="The timeout has passed, and the response has not yet been received.", + error_type="timeout", ) - yield TriggerEvent( - HITLTriggerEventSuccessPayload( - chosen_options=self.defaults, - params_input=self.params, - responded_by_user=None, - responded_at=resp.responded_at, - timedout=True, + ) + + # Case 3: Timeout fallback to default + resp = await sync_to_async(update_hitl_detail_response)( + ti_id=self.ti_id, + chosen_options=self.defaults, + params_input=self.params.dump(), + ) + if TYPE_CHECKING: + assert resp.responded_at is not None + + self.log.info( + "[HITL] timeout reached before receiving response, fallback to default %s", + self.defaults, + ) + return TriggerEvent( + HITLTriggerEventSuccessPayload( + chosen_options=self.defaults, + params_input=self.params.dump(), + responded_by_user=None, + responded_at=resp.responded_at, + timedout=True, + ) + ) + + async def _handle_response(self): + """Check if HITL response is ready and yield success if so.""" + resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id) + if TYPE_CHECKING: + assert resp.responded_by_user is not None + assert resp.responded_at is not None + + if not (resp.response_received and resp.chosen_options): + return None + + # validate input + if params_input := resp.params_input: + try: + for key, value in params_input.items(): + self.params[key] = value + except ParamValidationError as err: + return TriggerEvent( + HITLTriggerEventFailurePayload( + error=str(err), + error_type="validation", ) ) + + chosen_options_list = list(resp.chosen_options or []) + self.log.info( + "[HITL] responded_by=%s (id=%s) options=%s at %s", + resp.responded_by_user.name, + resp.responded_by_user.id, + chosen_options_list, + resp.responded_at, + ) + return TriggerEvent( + HITLTriggerEventSuccessPayload( + chosen_options=chosen_options_list, + params_input=params_input or {}, + responded_at=resp.responded_at, + responded_by_user=HITLUser( + id=resp.responded_by_user.id, + name=resp.responded_by_user.name, + ), + timedout=False, + ) + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Loop until the Human-in-the-loop response received or timeout reached.""" + while True: + if self.timeout_datetime and self.timeout_datetime < utcnow(): + event = await self._handle_timeout() + yield event return - resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id) - if resp.response_received and resp.chosen_options: - if TYPE_CHECKING: - assert resp.responded_by_user is not None - assert resp.responded_at is not None - chosen_options_list = list(resp.chosen_options) if resp.chosen_options is not None else None - self.log.info( - "[HITL] responded_by=%s (id=%s) options=%s at %s", - resp.responded_by_user.name, - resp.responded_by_user.id, - chosen_options_list, - resp.responded_at, - ) - yield TriggerEvent( - HITLTriggerEventSuccessPayload( - chosen_options=chosen_options_list, - params_input=resp.params_input or {}, - responded_at=resp.responded_at, - responded_by_user=HITLUser( - id=resp.responded_by_user.id, - name=resp.responded_by_user.name, - ), - timedout=False, - ) - ) + event = await self._handle_response() + if event: + yield event return + await asyncio.sleep(self.poke_interval) diff --git a/providers/standard/src/airflow/providers/standard/version_compat.py b/providers/standard/src/airflow/providers/standard/version_compat.py index 3539d070dab58..5316156bc03db 100644 --- a/providers/standard/src/airflow/providers/standard/version_compat.py +++ b/providers/standard/src/airflow/providers/standard/version_compat.py @@ -34,6 +34,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0) AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) +AIRFLOW_V_3_1_3_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 3) AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0) # BaseOperator: Use 3.1+ due to xcom_push method missing in SDK BaseOperator 3.0.x diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py index 2db7ef86934d4..18984cb8912af 100644 --- a/providers/standard/tests/unit/standard/operators/test_hitl.py +++ b/providers/standard/tests/unit/standard/operators/test_hitl.py @@ -18,8 +18,6 @@ import pytest -from airflow.providers.standard.exceptions import HITLRejectException, HITLTimeoutError, HITLTriggerEventError - from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS if not AIRFLOW_V_3_1_PLUS: @@ -33,9 +31,10 @@ import pytest from sqlalchemy import select -from airflow.exceptions import AirflowException, DownstreamTasksSkipped +from airflow.exceptions import AirflowException, DownstreamTasksSkipped, ParamValidationError from airflow.models import TaskInstance, Trigger from airflow.models.hitl import HITLDetail +from airflow.providers.standard.exceptions import HITLRejectException, HITLTimeoutError, HITLTriggerEventError from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.hitl import ( ApprovalOperator, @@ -46,9 +45,9 @@ from airflow.sdk import Param, timezone from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.execution_time.hitl import HITLUser -from airflow.utils.context import Context from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_3_PLUS if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -137,9 +136,27 @@ def test_validate_options_with_empty_options(self) -> None: params=ParamsDict({"input_1": 1}), ) - def test_validate_params_with__options(self) -> None: + @pytest.mark.parametrize( + ("params", "exc", "error_msg"), + ( + (ParamsDict({"_options": 1}), ValueError, '"_options" is not allowed in params'), + ( + ParamsDict({"param": Param("", type="integer")}), + ParamValidationError, + ( + "Invalid input for param param: '' is not of type 'integer'\n\n" + "Failed validating 'type' in schema:\n" + " {'type': 'integer'}\n\n" + "On instance:\n ''" + ), + ), + ), + ) + def test_validate_params( + self, params: ParamsDict, exc: type[ValueError | ParamValidationError], error_msg: str + ) -> None: # validate_params is called during initialization - with pytest.raises(ValueError, match='"_options" is not allowed in params'): + with pytest.raises(exc, match=error_msg): HITLOperator( task_id="hitl_test", subject="This is subject", @@ -147,7 +164,7 @@ def test_validate_params_with__options(self) -> None: body="This is body", defaults=["1"], multiple=False, - params=ParamsDict({"_options": 1}), + params=params, ) def test_validate_defaults(self) -> None: @@ -218,12 +235,21 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None: assert hitl_detail_model.body == "This is body" assert hitl_detail_model.defaults == ["1"] assert hitl_detail_model.multiple is False - assert hitl_detail_model.params == {"input_1": 1} assert hitl_detail_model.assignees == [{"id": "test", "name": "test"}] assert hitl_detail_model.responded_at is None assert hitl_detail_model.responded_by is None assert hitl_detail_model.chosen_options is None assert hitl_detail_model.params_input == {} + if AIRFLOW_V_3_1_3_PLUS: + assert hitl_detail_model.params == { + "input_1": { + "value": 1, + "description": None, + "schema": {}, + } + } + else: + assert hitl_detail_model.params == {"input_1": 1} assert notifier.called is True @@ -235,17 +261,55 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None: "ti_id": ti.id, "options": ["1", "2", "3", "4", "5"], "defaults": ["1"], - "params": {"input_1": 1}, + "params": { + "input_1": { + "value": 1, + "description": None, + "schema": {}, + } + }, "multiple": False, "timeout_datetime": None, "poke_interval": 5.0, } + @pytest.mark.skipif(not AIRFLOW_V_3_1_3_PLUS, reason="This only works in airflow-core >= 3.1.3") @pytest.mark.parametrize( ("input_params", "expected_params"), [ - (ParamsDict({"input": 1}), {"input": 1}), - ({"input": Param(5, type="integer", minimum=3)}, {"input": 5}), + ( + ParamsDict({"input": 1}), + { + "input": { + "description": None, + "schema": {}, + "value": 1, + }, + }, + ), + ( + {"input": Param(5, type="integer", minimum=3, description="test")}, + { + "input": { + "value": 5, + "schema": { + "minimum": 3, + "type": "integer", + }, + "description": "test", + } + }, + ), + ( + {"input": 1}, + { + "input": { + "value": 1, + "schema": {}, + "description": None, + } + }, + ), (None, {}), ], ) @@ -261,6 +325,20 @@ def test_serialzed_params( ) assert hitl_op.serialized_params == expected_params + @pytest.mark.skipif( + AIRFLOW_V_3_1_3_PLUS, + reason="Preserve the old behavior if airflow-core < 3.1.3. Otherwise the UI will break.", + ) + def test_serialzed_params_legacy(self) -> None: + hitl_op = HITLOperator( + task_id="hitl_test", + subject="This is subject", + body="This is body", + options=["1", "2", "3", "4", "5"], + params={"input": Param(1)}, + ) + assert hitl_op.serialized_params == {"input": 1} + def test_execute_complete(self) -> None: hitl_op = HITLOperator( task_id="hitl_test", @@ -330,21 +408,47 @@ def test_validate_chosen_options_with_invalid_content(self) -> None: }, ) - def test_validate_params_input_with_invalid_input(self) -> None: + @pytest.mark.parametrize( + ("params", "params_input", "exc", "error_msg"), + ( + ( + ParamsDict({"input": 1}), + {"no such key": 2, "input": 333}, + ValueError, + "params_input {'no such key': 2, 'input': 333} does not match params {'input': 1}", + ), + ( + ParamsDict({"input": Param(3, type="number", minimum=3)}), + {"input": 0}, + ParamValidationError, + ( + "Invalid input for param input: 0 is less than the minimum of 3\n\n" + "Failed validating 'minimum' in schema:\n.*" + ), + ), + ), + ) + def test_validate_params_input_with_invalid_input( + self, + params: ParamsDict, + params_input: dict[str, Any], + exc: type[ValueError | ParamValidationError], + error_msg: str, + ) -> None: hitl_op = HITLOperator( task_id="hitl_test", subject="This is subject", body="This is body", options=["1", "2", "3", "4", "5"], - params={"input": 1}, + params=params, ) - with pytest.raises(ValueError, match="no such key"): + with pytest.raises(exc, match=error_msg): hitl_op.execute_complete( context={}, event={ "chosen_options": ["1"], - "params_input": {"no such key": 2, "input": 333}, + "params_input": params_input, "responded_by_user": {"id": "test", "name": "test"}, }, ) diff --git a/providers/standard/tests/unit/standard/triggers/test_hitl.py b/providers/standard/tests/unit/standard/triggers/test_hitl.py index 1441d04bfc0ba..adb82ff00c056 100644 --- a/providers/standard/tests/unit/standard/triggers/test_hitl.py +++ b/providers/standard/tests/unit/standard/triggers/test_hitl.py @@ -17,6 +17,8 @@ from __future__ import annotations +from typing import Any + import pytest from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS @@ -40,16 +42,22 @@ from airflow.triggers.base import TriggerEvent TI_ID = uuid7() -default_trigger_args = { - "ti_id": TI_ID, - "options": ["1", "2", "3", "4", "5"], - "params": {"input": 1}, - "multiple": False, -} + + +@pytest.fixture +def default_trigger_args() -> dict[str, Any]: + return { + "ti_id": TI_ID, + "options": ["1", "2", "3", "4", "5"], + "params": { + "input": {"value": 1, "schema": {}, "description": None}, + }, + "multiple": False, + } class TestHITLTrigger: - def test_serialization(self): + def test_serialization(self, default_trigger_args): trigger = HITLTrigger( defaults=["1"], timeout_datetime=None, @@ -61,7 +69,7 @@ def test_serialization(self): assert kwargs == { "ti_id": TI_ID, "options": ["1", "2", "3", "4", "5"], - "params": {"input": 1}, + "params": {"input": {"value": 1, "description": None, "schema": {}}}, "defaults": ["1"], "multiple": False, "timeout_datetime": None, @@ -71,7 +79,7 @@ def test_serialization(self): @pytest.mark.db_test @pytest.mark.asyncio @mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response") - async def test_run_failed_due_to_timeout(self, mock_update, mock_supervisor_comms): + async def test_run_failed_due_to_timeout(self, mock_update, mock_supervisor_comms, default_trigger_args): trigger = HITLTrigger( timeout_datetime=utcnow() + timedelta(seconds=0.1), poke_interval=5, @@ -100,7 +108,9 @@ async def test_run_failed_due_to_timeout(self, mock_update, mock_supervisor_comm @pytest.mark.asyncio @mock.patch.object(HITLTrigger, "log") @mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response") - async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_log, mock_supervisor_comms): + async def test_run_fallback_to_default_due_to_timeout( + self, mock_update, mock_log, mock_supervisor_comms, default_trigger_args + ): trigger = HITLTrigger( defaults=["1"], timeout_datetime=utcnow() + timedelta(seconds=0.1), @@ -139,7 +149,7 @@ async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_lo @mock.patch.object(HITLTrigger, "log") @mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response") async def test_run_should_check_response_in_timeout_handler( - self, mock_update, mock_log, mock_supervisor_comms + self, mock_update, mock_log, mock_supervisor_comms, default_trigger_args ): # action time only slightly before timeout action_datetime = utcnow() + timedelta(seconds=0.1) @@ -186,7 +196,9 @@ async def test_run_should_check_response_in_timeout_handler( @pytest.mark.asyncio @mock.patch.object(HITLTrigger, "log") @mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response") - async def test_run(self, mock_update, mock_log, mock_supervisor_comms, time_machine): + async def test_run( + self, mock_update, mock_log, mock_supervisor_comms, time_machine, default_trigger_args + ): time_machine.move_to(datetime(2025, 7, 29, 2, 0, 0)) trigger = HITLTrigger( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 9867c67535a7e..9f841d4c38ec0 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -758,7 +758,7 @@ def add_response( body: str | None = None, defaults: list[str] | None = None, multiple: bool = False, - params: dict[str, Any] | None = None, + params: dict[str, dict[str, Any]] | None = None, assigned_users: list[HITLUser] | None = None, ) -> HITLDetailRequest: """Add a Human-in-the-loop response that waits for human response for a specific Task Instance.""" diff --git a/task-sdk/src/airflow/sdk/definitions/param.py b/task-sdk/src/airflow/sdk/definitions/param.py index 5da589d79e991..2c853ce1ffc90 100644 --- a/task-sdk/src/airflow/sdk/definitions/param.py +++ b/task-sdk/src/airflow/sdk/definitions/param.py @@ -137,7 +137,6 @@ class ParamsDict(MutableMapping[str, Any]): if they are not already. This class is to replace param's dictionary implicitly and ideally not needed to be used directly. - :param dict_obj: A dict or dict like object to init ParamsDict :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict """ diff --git a/task-sdk/tests/task_sdk/execution_time/test_hitl.py b/task-sdk/tests/task_sdk/execution_time/test_hitl.py index cd3682d922efe..5eb2dc7dab4d1 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_hitl.py +++ b/task-sdk/tests/task_sdk/execution_time/test_hitl.py @@ -39,7 +39,7 @@ def test_upsert_hitl_detail(mock_supervisor_comms) -> None: subject="Subject", body="Optional body", defaults=["Approve", "Reject"], - params={"input_1": 1}, + params={"input_1": {"value": 1, "description": None, "schema": {}}}, assigned_users=[HITLUser(id="test", name="test")], multiple=False, ) @@ -50,7 +50,7 @@ def test_upsert_hitl_detail(mock_supervisor_comms) -> None: subject="Subject", body="Optional body", defaults=["Approve", "Reject"], - params={"input_1": 1}, + params={"input_1": {"value": 1, "description": None, "schema": {}}}, assigned_users=[APIHITLUser(id="test", name="test")], multiple=False, )