Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions airflow-ctl/src/airflowctl/api/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
BackfillCollectionResponse,
BackfillPostBody,
BackfillResponse,
BulkActionResponse,
BulkBodyConnectionBody,
BulkBodyPoolBody,
BulkBodyVariableBody,
Expand Down Expand Up @@ -382,11 +381,11 @@ def create(
except ServerResponseError as e:
raise e

def bulk(self, connections: BulkBodyConnectionBody) -> BulkActionResponse | ServerResponseError:
def bulk(self, connections: BulkBodyConnectionBody) -> BulkResponse | ServerResponseError:
"""CRUD multiple connections."""
try:
self.response = self.client.patch("connections", json=connections.model_dump())
return BulkActionResponse.model_validate_json(self.response.content)
return BulkResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
raise e

Expand Down Expand Up @@ -609,11 +608,11 @@ def create(self, pool: PoolBody) -> PoolResponse | ServerResponseError:
except ServerResponseError as e:
raise e

def bulk(self, pools: BulkBodyPoolBody) -> BulkActionResponse | ServerResponseError:
def bulk(self, pools: BulkBodyPoolBody) -> BulkResponse | ServerResponseError:
"""CRUD multiple pools."""
try:
self.response = self.client.patch("pools", json=pools.model_dump())
return BulkActionResponse.model_validate_json(self.response.content)
return BulkResponse.model_validate_json(self.response.content)
except ServerResponseError as e:
raise e

Expand Down
12 changes: 6 additions & 6 deletions airflow-ctl/src/airflowctl/ctl/commands/pool_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@


@provide_api_client(kind=ClientKind.CLI)
def import_(args, api_client: Client = NEW_API_CLIENT):
def import_(args, api_client: Client = NEW_API_CLIENT) -> None:
"""Import pools from file."""
filepath = Path(args.file)
if not filepath.exists():
raise SystemExit(f"Missing pools file {args.file}")

success, failed = _import_helper(api_client, filepath)
if failed:
raise SystemExit(f"Failed to update pool(s): {', '.join(failed)}")
success, errors = _import_helper(api_client, filepath)
if errors:
raise SystemExit(f"Failed to update pool(s): {errors}")
rich.print(success)


@provide_api_client(kind=ClientKind.CLI)
def export(args, api_client: Client = NEW_API_CLIENT):
def export(args, api_client: Client = NEW_API_CLIENT) -> None:
"""
Export all pools.

Expand Down Expand Up @@ -119,4 +119,4 @@ def _import_helper(api_client: Client, filepath: Path):
)
result = api_client.pools.bulk(pools=bulk_body)
# Return the successful and failed entities directly from the response
return result.success, result.errors
return result.create.success, result.create.errors
24 changes: 12 additions & 12 deletions airflow-ctl/tests/airflow_ctl/api/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,10 @@ class TestConnectionsOperations:
]
)

connection_bulk_action_response = BulkActionResponse(
success=[connection_id],
errors=[],
connection_bulk_response = BulkResponse(
create=BulkActionResponse(success=[connection_id], errors=[]),
update=None,
delete=None,
)

def test_get(self):
Expand Down Expand Up @@ -522,13 +523,11 @@ def handle_request(request: httpx.Request) -> httpx.Response:
def test_bulk(self):
def handle_request(request: httpx.Request) -> httpx.Response:
assert request.url.path == "/api/v2/connections"
return httpx.Response(
200, json=json.loads(self.connection_bulk_action_response.model_dump_json())
)
return httpx.Response(200, json=json.loads(self.connection_bulk_response.model_dump_json()))

client = make_api_client(transport=httpx.MockTransport(handle_request))
response = client.connections.bulk(connections=self.connection_bulk_body)
assert response == self.connection_bulk_action_response
assert response == self.connection_bulk_response

def test_delete(self):
def handle_request(request: httpx.Request) -> httpx.Response:
Expand Down Expand Up @@ -954,9 +953,10 @@ class TestPoolsOperations:
pools=[pool_response],
total_entries=1,
)
pool_bulk_action_response = BulkActionResponse(
success=[pool_name],
errors=[],
pool_bulk_aresponse = BulkResponse(
create=BulkActionResponse(success=[pool_name], errors=[]),
update=None,
delete=None,
)

def test_get(self):
Expand Down Expand Up @@ -989,11 +989,11 @@ def handle_request(request: httpx.Request) -> httpx.Response:
def test_bulk(self):
def handle_request(request: httpx.Request) -> httpx.Response:
assert request.url.path == "/api/v2/pools"
return httpx.Response(200, json=json.loads(self.pool_bulk_action_response.model_dump_json()))
return httpx.Response(200, json=json.loads(self.pool_bulk_aresponse.model_dump_json()))

client = make_api_client(transport=httpx.MockTransport(handle_request))
response = client.pools.bulk(pools=self.pools_bulk_body)
assert response == self.pool_bulk_action_response
assert response == self.pool_bulk_aresponse

def test_delete(self):
def handle_request(request: httpx.Request) -> httpx.Response:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ def test_import_success(self, mock_client, tmp_path, capsys):
mock_response = mock.MagicMock()
mock_response.success = ["test_pool"]
mock_response.errors = []
mock_client.pools.bulk.return_value = mock_response

mock_bulk_builder = mock.MagicMock()
mock_bulk_builder.create = mock_response

mock_client.pools.bulk.return_value = mock_bulk_builder

pool_command.import_(args=mock.MagicMock(file=pools_file))

Expand Down