diff --git a/airflow-ctl/src/airflowctl/api/operations.py b/airflow-ctl/src/airflowctl/api/operations.py index 7350b1aa2d657..0161f566ff790 100644 --- a/airflow-ctl/src/airflowctl/api/operations.py +++ b/airflow-ctl/src/airflowctl/api/operations.py @@ -33,7 +33,6 @@ BackfillCollectionResponse, BackfillPostBody, BackfillResponse, - BulkActionResponse, BulkBodyConnectionBody, BulkBodyPoolBody, BulkBodyVariableBody, @@ -382,11 +381,11 @@ def create( except ServerResponseError as e: raise e - def bulk(self, connections: BulkBodyConnectionBody) -> BulkActionResponse | ServerResponseError: + def bulk(self, connections: BulkBodyConnectionBody) -> BulkResponse | ServerResponseError: """CRUD multiple connections.""" try: self.response = self.client.patch("connections", json=connections.model_dump()) - return BulkActionResponse.model_validate_json(self.response.content) + return BulkResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e @@ -609,11 +608,11 @@ def create(self, pool: PoolBody) -> PoolResponse | ServerResponseError: except ServerResponseError as e: raise e - def bulk(self, pools: BulkBodyPoolBody) -> BulkActionResponse | ServerResponseError: + def bulk(self, pools: BulkBodyPoolBody) -> BulkResponse | ServerResponseError: """CRUD multiple pools.""" try: self.response = self.client.patch("pools", json=pools.model_dump()) - return BulkActionResponse.model_validate_json(self.response.content) + return BulkResponse.model_validate_json(self.response.content) except ServerResponseError as e: raise e diff --git a/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py b/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py index 4cdf889bf0aae..bc0009e339147 100644 --- a/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py +++ b/airflow-ctl/src/airflowctl/ctl/commands/pool_command.py @@ -34,20 +34,20 @@ @provide_api_client(kind=ClientKind.CLI) -def import_(args, api_client: Client = NEW_API_CLIENT): +def import_(args, api_client: Client = NEW_API_CLIENT) -> None: """Import pools from file.""" filepath = Path(args.file) if not filepath.exists(): raise SystemExit(f"Missing pools file {args.file}") - success, failed = _import_helper(api_client, filepath) - if failed: - raise SystemExit(f"Failed to update pool(s): {', '.join(failed)}") + success, errors = _import_helper(api_client, filepath) + if errors: + raise SystemExit(f"Failed to update pool(s): {errors}") rich.print(success) @provide_api_client(kind=ClientKind.CLI) -def export(args, api_client: Client = NEW_API_CLIENT): +def export(args, api_client: Client = NEW_API_CLIENT) -> None: """ Export all pools. @@ -119,4 +119,4 @@ def _import_helper(api_client: Client, filepath: Path): ) result = api_client.pools.bulk(pools=bulk_body) # Return the successful and failed entities directly from the response - return result.success, result.errors + return result.create.success, result.create.errors diff --git a/airflow-ctl/tests/airflow_ctl/api/test_operations.py b/airflow-ctl/tests/airflow_ctl/api/test_operations.py index 9d4588bdde0bf..1a894990183a0 100644 --- a/airflow-ctl/tests/airflow_ctl/api/test_operations.py +++ b/airflow-ctl/tests/airflow_ctl/api/test_operations.py @@ -487,9 +487,10 @@ class TestConnectionsOperations: ] ) - connection_bulk_action_response = BulkActionResponse( - success=[connection_id], - errors=[], + connection_bulk_response = BulkResponse( + create=BulkActionResponse(success=[connection_id], errors=[]), + update=None, + delete=None, ) def test_get(self): @@ -522,13 +523,11 @@ def handle_request(request: httpx.Request) -> httpx.Response: def test_bulk(self): def handle_request(request: httpx.Request) -> httpx.Response: assert request.url.path == "/api/v2/connections" - return httpx.Response( - 200, json=json.loads(self.connection_bulk_action_response.model_dump_json()) - ) + return httpx.Response(200, json=json.loads(self.connection_bulk_response.model_dump_json())) client = make_api_client(transport=httpx.MockTransport(handle_request)) response = client.connections.bulk(connections=self.connection_bulk_body) - assert response == self.connection_bulk_action_response + assert response == self.connection_bulk_response def test_delete(self): def handle_request(request: httpx.Request) -> httpx.Response: @@ -954,9 +953,10 @@ class TestPoolsOperations: pools=[pool_response], total_entries=1, ) - pool_bulk_action_response = BulkActionResponse( - success=[pool_name], - errors=[], + pool_bulk_aresponse = BulkResponse( + create=BulkActionResponse(success=[pool_name], errors=[]), + update=None, + delete=None, ) def test_get(self): @@ -989,11 +989,11 @@ def handle_request(request: httpx.Request) -> httpx.Response: def test_bulk(self): def handle_request(request: httpx.Request) -> httpx.Response: assert request.url.path == "/api/v2/pools" - return httpx.Response(200, json=json.loads(self.pool_bulk_action_response.model_dump_json())) + return httpx.Response(200, json=json.loads(self.pool_bulk_aresponse.model_dump_json())) client = make_api_client(transport=httpx.MockTransport(handle_request)) response = client.pools.bulk(pools=self.pools_bulk_body) - assert response == self.pool_bulk_action_response + assert response == self.pool_bulk_aresponse def test_delete(self): def handle_request(request: httpx.Request) -> httpx.Response: diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py index b5644d6627fb0..44cc341ee1c36 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py @@ -81,7 +81,11 @@ def test_import_success(self, mock_client, tmp_path, capsys): mock_response = mock.MagicMock() mock_response.success = ["test_pool"] mock_response.errors = [] - mock_client.pools.bulk.return_value = mock_response + + mock_bulk_builder = mock.MagicMock() + mock_bulk_builder.create = mock_response + + mock_client.pools.bulk.return_value = mock_bulk_builder pool_command.import_(args=mock.MagicMock(file=pools_file))