From 77f111faf6a640566bb8eaefb0a790e9a411142e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 10 Feb 2024 23:04:36 +0800 Subject: [PATCH] fix(providers/google): fix how GKEPodAsyncHook.service_file_as_context is used --- .../google/cloud/hooks/kubernetes_engine.py | 6 +++--- .../google/cloud/hooks/test_kubernetes_engine.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py index 7f74447827da..fd44a1dfa3c8 100644 --- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -508,7 +508,7 @@ async def get_pod(self, name: str, namespace: str) -> V1Pod: :param name: Name of the pod. :param namespace: Name of the pod's namespace. """ - async with self.service_file_as_context() as service_file: # type: ignore[attr-defined] + with await self.service_file_as_context() as service_file: # type: ignore[attr-defined] async with Token(scopes=self.scopes, service_file=service_file) as token: async with self.get_conn(token) as connection: v1_api = async_client.CoreV1Api(connection) @@ -524,7 +524,7 @@ async def delete_pod(self, name: str, namespace: str): :param name: Name of the pod. :param namespace: Name of the pod's namespace. """ - async with self.service_file_as_context() as service_file: # type: ignore[attr-defined] + with await self.service_file_as_context() as service_file: # type: ignore[attr-defined] async with Token(scopes=self.scopes, service_file=service_file) as token, self.get_conn( token ) as connection: @@ -551,7 +551,7 @@ async def read_logs(self, name: str, namespace: str): :param name: Name of the pod. :param namespace: Name of the pod's namespace. """ - async with self.service_file_as_context() as service_file: # type: ignore[attr-defined] + with await self.service_file_as_context() as service_file: # type: ignore[attr-defined] async with Token(scopes=self.scopes, service_file=service_file) as token, self.get_conn( token ) as connection: diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py index e52cfae870b8..fae3db76e992 100644 --- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py @@ -324,8 +324,8 @@ def async_hook(self): async def test_get_pod( self, read_namespace_pod_mock, get_conn_mock, mock_token, async_hook, mock_service_file ): - async_hook.service_file_as_context = mock.MagicMock() - async_hook.service_file_as_context.return_value.__aenter__.return_value = mock_service_file + async_hook.service_file_as_context = mock.AsyncMock() + async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file self.make_mock_awaitable(read_namespace_pod_mock) @@ -347,8 +347,8 @@ async def test_get_pod( async def test_delete_pod( self, delete_namespaced_pod, get_conn_mock, mock_token, async_hook, mock_service_file ): - async_hook.service_file_as_context = mock.MagicMock() - async_hook.service_file_as_context.return_value.__aenter__.return_value = mock_service_file + async_hook.service_file_as_context = mock.AsyncMock() + async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file self.make_mock_awaitable(delete_namespaced_pod) @@ -372,8 +372,8 @@ async def test_delete_pod( async def test_read_logs( self, read_namespaced_pod_log, get_conn_mock, mock_token, async_hook, mock_service_file, caplog ): - async_hook.service_file_as_context = mock.MagicMock() - async_hook.service_file_as_context.return_value.__aenter__.return_value = mock_service_file + async_hook.service_file_as_context = mock.AsyncMock() + async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file self.make_mock_awaitable(read_namespaced_pod_log, result="Test string #1\nTest string #2\n")