Skip to content

Commit

Permalink
Fix mypy issues for google
Browse files Browse the repository at this point in the history
  • Loading branch information
rajaths010494 authored and kaxil committed Mar 8, 2022
1 parent d04d772 commit 93f848d
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
17 changes: 9 additions & 8 deletions astronomer/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
This module contains a BigQueryHookAsync
"""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

from aiohttp import ClientSession as Session
from aiohttp import ClientSession as ClientSession
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, _bq_cast
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from gcloud.aio.bigquery import Job
from google.cloud.bigquery import CopyJob, ExtractJob, LoadJob, QueryJob
from requests import Session

from astronomer.providers.google.common.hooks.base_google import GoogleBaseHookAsync

Expand Down Expand Up @@ -81,11 +82,11 @@ class BigQueryHookAsync(GoogleBaseHookAsync):
sync_hook_class = _BigQueryHook

async def get_job_instance(
self, project_id: Optional[str], job_id: Optional[str], session: Session
self, project_id: Optional[str], job_id: Optional[str], session: ClientSession
) -> Job:
"""Get the specified job resource by job ID and project ID."""
with await self.service_file_as_context() as f:
return Job(job_id=job_id, project=project_id, service_file=f, session=session)
return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session))

async def get_job_status(
self,
Expand All @@ -95,11 +96,11 @@ async def get_job_status(
"""Polls for job status asynchronously using gcloud-aio.
Note that an OSError is raised when Job results are still pending.
Exception means that Job finished with errors"""
async with Session() as s:
async with ClientSession() as s:
try:
self.log.info("Executing get_job_status...")
job_client = await self.get_job_instance(project_id, job_id, s)
job_status_response = await job_client.result(s)
job_status_response = await job_client.result(cast(Session, s))
if job_status_response:
job_status = "success"
except OSError:
Expand All @@ -118,10 +119,10 @@ async def get_job_output(
Get the big query job output for the given job id
asynchronously using gcloud-aio.
"""
async with Session() as session:
async with ClientSession() as session:
self.log.info("Executing get_job_output..")
job_client = await self.get_job_instance(project_id, job_id, session)
job_query_response = await job_client.get_query_results(session)
job_query_response = await job_client.get_query_results(cast(Session, session))
return job_query_response

def get_records(self, query_results: Dict[str, Any], nocast: bool = True) -> List[Any]:
Expand Down
8 changes: 5 additions & 3 deletions astronomer/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""This module contains a Google Cloud Storage hook."""
from typing import cast

from aiohttp import ClientSession as Session
from aiohttp import ClientSession as ClientSession
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from gcloud.aio.storage import Storage
from requests import Session

from astronomer.providers.google.common.hooks.base_google import GoogleBaseHookAsync

Expand All @@ -12,9 +14,9 @@
class GCSHookAsync(GoogleBaseHookAsync):
sync_hook_class = GCSHook

async def get_storage_client(self, session: Session) -> Storage:
async def get_storage_client(self, session: ClientSession) -> Storage:
"""
Returns a Google Cloud Storage service object.
"""
with await self.service_file_as_context() as file:
return Storage(service_file=file, session=session)
return Storage(service_file=file, session=cast(Session, session))
3 changes: 2 additions & 1 deletion astronomer/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,12 @@ def execute(self, context: Context) -> None: # type: ignore[override]
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any:
if event["status"] == "error":
raise AirflowException(event["message"])

self.log.info("Total extracted rows: %s", len(event["records"]))
return event["records"]


class BigQueryIntervalCheckOperatorAsync(BigQueryIntervalCheckOperator):
Expand Down
4 changes: 2 additions & 2 deletions astronomer/providers/google/cloud/triggers/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime
from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple

from aiohttp import ClientSession as Session
from aiohttp import ClientSession as ClientSession
from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.google.cloud.hooks.gcs import GCSHookAsync
Expand Down Expand Up @@ -79,7 +79,7 @@ async def _object_exists(self, hook: GCSHookAsync, bucket_name: str, object_name
:param object_name: The name of the blob_name to check in the Google cloud
storage bucket.
"""
async with Session() as s:
async with ClientSession() as s:
client = await hook.get_storage_client(s)
bucket = client.get_bucket(bucket_name)
object_response = await bucket.blob_exists(blob_name=object_name)
Expand Down
6 changes: 3 additions & 3 deletions tests/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_insert_job(self, mock_client, mock_query_job, nowait):


@pytest.mark.asyncio
@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.Session")
@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.ClientSession")
async def test_get_job_instance(mock_session):
hook = BigQueryHookAsync()
result = await hook.get_job_instance(project_id=PROJECT_ID, job_id=JOB_ID, session=mock_session)
Expand All @@ -80,7 +80,7 @@ async def test_get_job_instance(mock_session):

@pytest.mark.asyncio
@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.BigQueryHookAsync.get_job_instance")
@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.Session")
@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.ClientSession")
async def test_get_job_status_success(mock_session, mock_job_instance):
hook = BigQueryHookAsync()
resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
Expand Down Expand Up @@ -111,7 +111,7 @@ async def test_get_job_status_exception(mock_job_instance, caplog):


@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.BigQueryHookAsync.get_job_instance")
@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.Session")
@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.ClientSession")
async def test_get_job_output_assert_once_with(mock_session, mock_job_instance):
hook = BigQueryHookAsync()
await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID)
Expand Down

0 comments on commit 93f848d

Please sign in to comment.