Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class BaseHITLDetail(BaseModel):
body: str | None = None
defaults: list[str] | None = None
multiple: bool = False
params: dict[str, Any] = Field(default_factory=dict)
params: Mapping = Field(default_factory=dict)
assigned_users: list[HITLUser] = Field(default_factory=list)
created_at: datetime

Expand All @@ -74,7 +74,20 @@ class BaseHITLDetail(BaseModel):
@classmethod
def get_params(cls, params: dict[str, Any]) -> dict[str, Any]:
"""Convert params attribute to dict representation."""
return {k: v.dump() if getattr(v, "dump", None) else v for k, v in params.items()}
return {
key: value
if BaseHITLDetail._is_param(value)
else {
"value": value,
"description": None,
"schema": {},
}
for key, value in params.items()
}

@staticmethod
def _is_param(value: Any) -> bool:
return isinstance(value, dict) and all(key in value for key in ("description", "schema", "value"))


class HITLDetail(BaseHITLDetail):
Expand Down
51 changes: 30 additions & 21 deletions airflow-core/src/airflow/ui/src/utils/hitl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import type { TFunction } from "i18next";

import type { HITLDetail } from "openapi/requests/types.gen";
import type { ParamsSpec } from "src/queries/useDagParams";
import type { ParamSchema, ParamsSpec } from "src/queries/useDagParams";

export type HITLResponseParams = {
chosen_options?: Array<string>;
Expand Down Expand Up @@ -70,7 +70,7 @@ export const getHITLParamsDict = (
searchParams: URLSearchParams,
): ParamsSpec => {
const paramsDict: ParamsSpec = {};
const { preloadedHITLOptions, preloadedHITLParams } = getPreloadHITLFormData(searchParams, hitlDetail);
const { preloadedHITLOptions } = getPreloadHITLFormData(searchParams, hitlDetail);
const isApprovalTask =
hitlDetail.options.includes("Approve") &&
hitlDetail.options.includes("Reject") &&
Expand Down Expand Up @@ -108,27 +108,36 @@ export const getHITLParamsDict = (
const sourceParams = hitlDetail.response_received ? hitlDetail.params_input : hitlDetail.params;

Object.entries(sourceParams ?? {}).forEach(([key, value]) => {
const valueType = typeof value === "number" ? "number" : "string";
if (!hitlDetail.params) {
return;
}
const paramData = hitlDetail.params[key] as ParamsSpec | undefined;

const description: string =
paramData && typeof paramData.description === "string" ? paramData.description : "";

const schema: ParamSchema = {
const: undefined,
description_md: "",
enum: undefined,
examples: undefined,
format: undefined,
items: undefined,
maximum: undefined,
maxLength: undefined,
minimum: undefined,
minLength: undefined,
section: undefined,
title: key,
type: typeof value === "number" ? "number" : "string",
values_display: undefined,
...(paramData?.schema && typeof paramData.schema === "object" ? paramData.schema : {}),
};

paramsDict[key] = {
description: "",
schema: {
const: undefined,
description_md: "",
enum: undefined,
examples: undefined,
format: undefined,
items: undefined,
maximum: undefined,
maxLength: undefined,
minimum: undefined,
minLength: undefined,
section: undefined,
title: key,
type: valueType,
values_display: undefined,
},
value: preloadedHITLParams[key] ?? value,
description,
schema,
value: paramData?.value ?? value,
};
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def expected_sample_hitl_detail_dict(sample_ti: TaskInstance) -> dict[str, Any]:
"defaults": ["Approve"],
"multiple": False,
"options": ["Approve", "Reject"],
"params": {"input_1": 1},
"params": {"input_1": {"value": 1, "schema": {}, "description": None}},
"assigned_users": [],
"created_at": mock.ANY,
"params_input": {},
Expand Down Expand Up @@ -621,7 +621,7 @@ def test_should_respond_200_with_existing_response_and_concrete_query(
"body": "this is body 0",
"defaults": ["Approve"],
"multiple": False,
"params": {"input_1": 1},
"params": {"input_1": {"value": 1, "schema": {}, "description": None}},
"assigned_users": [],
"created_at": DEFAULT_CREATED_AT.isoformat().replace("+00:00", "Z"),
"responded_by_user": None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2275,7 +2275,7 @@ def test_should_respond_200_with_hitl(
"defaults": ["Approve"],
"multiple": False,
"options": ["Approve", "Reject"],
"params": {"input_1": 1},
"params": {"input_1": {"value": 1, "description": None, "schema": {}}},
"params_input": {},
"responded_at": None,
"responded_by_user": None,
Expand Down Expand Up @@ -3554,7 +3554,7 @@ def test_should_respond_200_with_hitl(
"defaults": ["Approve"],
"multiple": False,
"options": ["Approve", "Reject"],
"params": {"input_1": 1},
"params": {"input_1": {"value": 1, "description": None, "schema": {}}},
"params_input": {},
"responded_at": None,
"responded_by_user": None,
Expand Down
1 change: 1 addition & 0 deletions devel-common/src/tests_common/test_utils/version_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_0_3_PLUS = get_base_airflow_version_tuple() >= (3, 0, 3)
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_3_PLUS, AIRFLOW_V_3_1_PLUS

if not AIRFLOW_V_3_1_PLUS:
raise AirflowOptionalProviderFeatureException("Human in the loop functionality needs Airflow 3.1+.")
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(
self.multiple = multiple

self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {})

self.notifiers: Sequence[BaseNotifier] = (
[notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or []
)
Expand All @@ -110,6 +111,7 @@ def validate_params(self) -> None:
Raises:
ValueError: If `"_options"` key is present in `params`, which is not allowed.
"""
self.params.validate()
if "_options" in self.params:
raise ValueError('"_options" is not allowed in params')

Expand Down Expand Up @@ -165,8 +167,10 @@ def execute(self, context: Context):
)

@property
def serialized_params(self) -> dict[str, Any]:
return self.params.dump() if isinstance(self.params, ParamsDict) else self.params
def serialized_params(self) -> dict[str, dict[str, Any]]:
if not AIRFLOW_V_3_1_3_PLUS:
return self.params.dump() if isinstance(self.params, ParamsDict) else self.params
return {k: self.params.get_param(k).serialize() for k in self.params}

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
if "error" in event:
Expand Down Expand Up @@ -196,13 +200,12 @@ def validate_chosen_options(self, chosen_options: list[str]) -> None:

def validate_params_input(self, params_input: Mapping) -> None:
"""Check whether user provide valid params input."""
if (
self.serialized_params is not None
and params_input is not None
and set(self.serialized_params.keys()) ^ set(params_input)
):
if self.params and params_input and set(self.serialized_params.keys()) ^ set(params_input):
raise ValueError(f"params_input {params_input} does not match params {self.params}")

for key, value in params_input.items():
self.params[key] = value

def generate_link_to_ui(
self,
*,
Expand Down
Loading
Loading