Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class UpdateHITLDetailPayload(BaseModel):
"""Schema for updating the content of a Human-in-the-loop detail."""

chosen_options: list[str]
chosen_options: list[str] = Field(min_length=1)
params_input: Mapping = Field(default_factory=dict)


Expand All @@ -39,7 +39,7 @@ class HITLDetailResponse(BaseModel):

user_id: str
response_at: datetime
chosen_options: list[str]
chosen_options: list[str] = Field(min_length=1)
params_input: Mapping = Field(default_factory=dict)


Expand All @@ -49,7 +49,7 @@ class HITLDetail(BaseModel):
task_instance: TaskInstanceResponse

# User Request Detail
options: list[str]
options: list[str] = Field(min_length=1)
subject: str
body: str | None = None
defaults: list[str] | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10219,6 +10219,7 @@ components:
items:
type: string
type: array
minItems: 1
title: Options
subject:
type: string
Expand Down Expand Up @@ -10305,6 +10306,7 @@ components:
items:
type: string
type: array
minItems: 1
title: Chosen Options
params_input:
additionalProperties: true
Expand Down Expand Up @@ -11855,6 +11857,7 @@ components:
items:
type: string
type: array
minItems: 1
title: Chosen Options
params_input:
additionalProperties: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class HITLDetailRequest(BaseModel):
"""Schema for the request part of a Human-in-the-loop detail for a specific task instance."""

ti_id: UUID
options: list[str]
options: list[str] = Field(min_length=1)
subject: str
body: str | None = None
defaults: list[str] | None = None
Expand All @@ -42,7 +42,7 @@ class UpdateHITLDetailPayload(BaseModel):
"""Schema for writing the response part of a Human-in-the-loop detail for a specific task instance."""

ti_id: UUID
chosen_options: list[str]
chosen_options: list[str] = Field(min_length=1)
params_input: dict[str, Any] = Field(default_factory=dict)


Expand All @@ -52,6 +52,7 @@ class HITLDetailResponse(BaseModel):
response_received: bool
user_id: str | None
response_at: datetime | None
# It's empty if the user has not yet responded.
chosen_options: list[str] | None
params_input: dict[str, Any] = Field(default_factory=dict)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3372,6 +3372,7 @@ export const $HITLDetail = {
type: 'string'
},
type: 'array',
minItems: 1,
title: 'Options'
},
subject: {
Expand Down Expand Up @@ -3503,6 +3504,7 @@ export const $HITLDetailResponse = {
type: 'string'
},
type: 'array',
minItems: 1,
title: 'Chosen Options'
},
params_input: {
Expand Down Expand Up @@ -5819,6 +5821,7 @@ export const $UpdateHITLDetailPayload = {
type: 'string'
},
type: 'array',
minItems: 1,
title: 'Chosen Options'
},
params_input: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,22 @@ def test_should_respond_200_with_existing_response(
"response_at": "2025-07-03T00:00:00Z",
}

def test_should_respond_401(
self,
unauthenticated_test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthenticated_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}")
assert response.status_code == 401

def test_should_respond_403(
self,
unauthorized_test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthorized_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}")
assert response.status_code == 403

def test_should_respond_404(
self,
test_client: TestClient,
Expand Down Expand Up @@ -288,21 +304,18 @@ def test_should_respond_409(
)
}

def test_should_respond_401(
@pytest.mark.usefixtures("sample_hitl_detail")
def test_should_respond_422_with_empty_option(
self,
unauthenticated_test_client: TestClient,
test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthenticated_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}")
assert response.status_code == 401
response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}",
json={"chosen_options": [], "params_input": {"input_1": 2}},
)

def test_should_respond_403(
self,
unauthorized_test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthorized_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}")
assert response.status_code == 403
assert response.status_code == 422


class TestUpdateMappedTIHITLDetail:
Expand All @@ -326,6 +339,22 @@ def test_should_respond_200_with_existing_response(
"response_at": "2025-07-03T00:00:00Z",
}

def test_should_respond_401(
self,
unauthenticated_test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthenticated_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}/-1")
assert response.status_code == 401

def test_should_respond_403(
self,
unauthorized_test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthorized_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}/-1")
assert response.status_code == 403

def test_should_respond_404(
self,
test_client: TestClient,
Expand Down Expand Up @@ -371,21 +400,18 @@ def test_should_respond_409(
)
}

def test_should_respond_401(
@pytest.mark.usefixtures("sample_hitl_detail")
def test_should_respond_422_with_empty_option(
self,
unauthenticated_test_client: TestClient,
test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthenticated_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}/-1")
assert response.status_code == 401
response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}/-1",
json={"chosen_options": [], "params_input": {"input_1": 2}},
)

def test_should_respond_403(
self,
unauthorized_test_client: TestClient,
sample_ti_url_identifier: str,
) -> None:
response = unauthorized_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}/-1")
assert response.status_code == 403
assert response.status_code == 422


class TestGetHITLDetailEndpoint:
Expand All @@ -400,16 +426,6 @@ def test_should_respond_200_with_existing_response(
assert response.status_code == 200
assert response.json() == expected_sample_hitl_detail_dict

def test_should_respond_404(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
expected_ti_not_found_error_msg: str,
) -> None:
response = test_client.get(f"/hitlDetails/{sample_ti_url_identifier}")
assert response.status_code == 404
assert response.json() == {"detail": expected_ti_not_found_error_msg}

def test_should_respond_401(
self,
unauthenticated_test_client: TestClient,
Expand All @@ -426,6 +442,16 @@ def test_should_respond_403(
response = unauthorized_test_client.get(f"/hitlDetails/{sample_ti_url_identifier}")
assert response.status_code == 403

def test_should_respond_404(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
expected_ti_not_found_error_msg: str,
) -> None:
response = test_client.get(f"/hitlDetails/{sample_ti_url_identifier}")
assert response.status_code == 404
assert response.json() == {"detail": expected_ti_not_found_error_msg}


class TestGetMappedTIHITLDetail:
@pytest.mark.usefixtures("sample_hitl_detail")
Expand All @@ -439,16 +465,6 @@ def test_should_respond_200_with_existing_response(
assert response.status_code == 200
assert response.json() == expected_sample_hitl_detail_dict

def test_should_respond_404(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
expected_mapped_ti_not_found_error_msg: str,
) -> None:
response = test_client.get(f"/hitlDetails/{sample_ti_url_identifier}/-1")
assert response.status_code == 404
assert response.json() == {"detail": expected_mapped_ti_not_found_error_msg}

def test_should_respond_401(
self,
unauthenticated_test_client: TestClient,
Expand Down Expand Up @@ -542,3 +558,13 @@ def test_should_respond_401(self, unauthenticated_test_client: TestClient) -> No
def test_should_respond_403(self, unauthorized_test_client: TestClient) -> None:
response = unauthorized_test_client.get("/hitlDetails/")
assert response.status_code == 403

def test_should_respond_404(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
expected_mapped_ti_not_found_error_msg: str,
) -> None:
response = test_client.get(f"/hitlDetails/{sample_ti_url_identifier}/-1")
assert response.status_code == 404
assert response.json() == {"detail": expected_mapped_ti_not_found_error_msg}
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,29 @@ def test_upsert_hitl_detail(
}


def test_upsert_hitl_detail_with_empty_option(
client: TestClient,
create_task_instance: CreateTaskInstance,
session: Session,
) -> None:
ti = create_task_instance()
session.commit()

response = client.post(
f"/execution/hitlDetails/{ti.id}",
json={
"ti_id": ti.id,
"subject": "This is subject",
"body": "this is body",
"options": [],
"defaults": ["Approve"],
"multiple": False,
"params": {"input_1": 1},
},
)
assert response.status_code == 422


@time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False)
@pytest.mark.usefixtures("sample_hitl_detail")
def test_update_hitl_detail(client: Client, sample_ti: TaskInstance) -> None:
Expand All @@ -150,6 +173,18 @@ def test_update_hitl_detail(client: Client, sample_ti: TaskInstance) -> None:
}


def test_update_hitl_detail_without_option(client: Client, sample_ti: TaskInstance) -> None:
response = client.patch(
f"/execution/hitlDetails/{sample_ti.id}",
json={
"ti_id": sample_ti.id,
"chosen_options": [],
"params_input": {"input_1": 2},
},
)
assert response.status_code == 422


def test_update_hitl_detail_without_ti(client: Client) -> None:
ti_id = str(uuid7())
response = client.patch(
Expand Down
6 changes: 3 additions & 3 deletions airflow-ctl/src/airflowctl/api/datamodels/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ class HITLDetailResponse(BaseModel):

user_id: Annotated[str, Field(title="User Id")]
response_at: Annotated[datetime, Field(title="Response At")]
chosen_options: Annotated[list[str], Field(title="Chosen Options")]
chosen_options: Annotated[list[str], Field(min_length=1, title="Chosen Options")]
params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None


Expand Down Expand Up @@ -927,7 +927,7 @@ class UpdateHITLDetailPayload(BaseModel):
Schema for updating the content of a Human-in-the-loop detail.
"""

chosen_options: Annotated[list[str], Field(title="Chosen Options")]
chosen_options: Annotated[list[str], Field(min_length=1, title="Chosen Options")]
params_input: Annotated[dict[str, Any] | None, Field(title="Params Input")] = None


Expand Down Expand Up @@ -1818,7 +1818,7 @@ class HITLDetail(BaseModel):
"""

task_instance: TaskInstanceResponse
options: Annotated[list[str], Field(title="Options")]
options: Annotated[list[str], Field(min_length=1, title="Options")]
subject: Annotated[str, Field(title="Subject")]
body: Annotated[str | None, Field(title="Body")] = None
defaults: Annotated[list[str] | None, Field(title="Defaults")] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ def __init__(
[notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or []
)

self.validate_options()
self.validate_defaults()

def validate_options(self) -> None:
if not self.options:
raise ValueError('"options" cannot be empty.')

def validate_defaults(self) -> None:
"""
Validate whether the given defaults pass the following criteria.
Expand Down
24 changes: 24 additions & 0 deletions providers/standard/tests/unit/standard/operators/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@


class TestHITLOperator:
def test_validate_options(self) -> None:
hitl_op = HITLOperator(
task_id="hitl_test",
subject="This is subject",
options=["1", "2", "3", "4", "5"],
body="This is body",
defaults=["1"],
multiple=False,
params=ParamsDict({"input_1": 1}),
)
hitl_op.validate_defaults()

def test_validate_options_with_empty_options(self) -> None:
with pytest.raises(ValueError, match='"options" cannot be empty.'):
HITLOperator(
task_id="hitl_test",
subject="This is subject",
options=[],
body="This is body",
defaults=["1"],
multiple=False,
params=ParamsDict({"input_1": 1}),
)

def test_validate_defaults(self) -> None:
hitl_op = HITLOperator(
task_id="hitl_test",
Expand Down
Loading
Loading