Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy errors in Microsoft Azure provider #19923

Merged
merged 6 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/azure/hooks/container_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class AzureContainerInstanceHook(AzureBaseHook):
conn_type = 'azure_container_instance'
hook_name = 'Azure Container Instance'

def __init__(self, *args, **kwargs) -> None:
super().__init__(sdk_client=ContainerInstanceManagementClient, *args, **kwargs)
def __init__(self, conn_id: str = default_conn_name) -> None:
super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=conn_id)
self.connection = self.get_conn()

def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:
Expand Down
15 changes: 7 additions & 8 deletions airflow/providers/microsoft/azure/hooks/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_ui_field_behaviour() -> Dict:
},
}

def __init__(self, azure_data_factory_conn_id: Optional[str] = default_conn_name):
def __init__(self, azure_data_factory_conn_id: str = default_conn_name):
self._conn: DataFactoryManagementClient = None
self.conn_id = azure_data_factory_conn_id
super().__init__()
Expand All @@ -144,13 +144,12 @@ def get_conn(self) -> DataFactoryManagementClient:
tenant = conn.extra_dejson.get('extra__azure_data_factory__tenantId')
subscription_id = conn.extra_dejson.get('extra__azure_data_factory__subscriptionId')

credential = None
if conn.login is not None and conn.password is not None:
credential = ClientSecretCredential(
client_id=conn.login, client_secret=conn.password, tenant_id=tenant
client_id=conn.login, client_secret=conn.password, tenant_id=tenant # type: ignore
)
else:
credential = DefaultAzureCredential()
credential = DefaultAzureCredential() # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add credential: Any before the if instead. These credential classes are opaque and internal to Azure, and we are passing them directly back to Azure. This is better than # type: ignore because it is more obvious what kind of errors we are trying to ignore.

self._conn = self._create_client(credential, subscription_id)

return self._conn
Expand Down Expand Up @@ -623,8 +622,8 @@ def wait_for_pipeline_run_status(
expected_statuses: Union[str, Set[str]],
resource_group_name: Optional[str] = None,
factory_name: Optional[str] = None,
check_interval: Optional[int] = 60,
timeout: Optional[int] = 60 * 60 * 24 * 7,
check_interval: int = 60,
timeout: int = 60 * 60 * 24 * 7,
) -> bool:
"""
Waits for a pipeline run to match an expected status.
Expand All @@ -643,7 +642,7 @@ def wait_for_pipeline_run_status(
"factory_name": factory_name,
"resource_group_name": resource_group_name,
}
pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info) # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels unnecessary. What is the error we’re trying to work around here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

airflow/providers/microsoft/azure/hooks/data_factory.py:671: error: Argument 1 to "get_pipeline_run_status" of "AzureDataFactoryHook" has incompatible type "**Dict[str, Optional[str]]"; expected
"str"
            pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)

If I enforce keyword-only arguments on the get_pipeline_run_status() function, mypy doesn't complain. Alternatively, we could explicitly pass a dictionary rather than using the **kwargs syntax.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I guess I could use a TypedDict here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pipeline_run_info: Dict[str, Any] = resolves the check as well, but it's not entirely correct typing.


start_time = time.monotonic()

Expand All @@ -660,7 +659,7 @@ def wait_for_pipeline_run_status(
# Wait to check the status of the pipeline run based on the ``check_interval`` configured.
time.sleep(check_interval)

pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info)
pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info) # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.


return pipeline_run_status in expected_statuses

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_conn(self) -> BlobServiceClient:
app_id = conn.login
app_secret = conn.password
tenant = extra.get('tenant_id') or extra.get('extra__wasb__tenant_id')
token_credential = ClientSecretCredential(tenant, app_id, app_secret)
token_credential = ClientSecretCredential(tenant, app_id, app_secret) # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error here was:

airflow/providers/microsoft/azure/hooks/wasb.py:137: error: Argument 1 to "ClientSecretCredential" has incompatible type "Optional[Any]"; expected "str"
                token_credential = ClientSecretCredential(tenant, app_id, app_secret)

I'll address this properly.

return BlobServiceClient(account_url=conn.host, credential=token_credential)
sas_token = extra.get('sas_token') or extra.get('extra__wasb__sas_token')
if sas_token and sas_token.startswith('https'):
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/log/wasb_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from azure.common import AzureHttpError

try:
from functools import cached_property
from functools import cached_property # type: ignore[attr-defined]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice. I have to do the same in one of my PRs :).

except ImportError:
from cached_property import cached_property

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/secrets/key_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from azure.keyvault.secrets import SecretClient

try:
from functools import cached_property
from functools import cached_property # type: ignore[attr-defined]
except ImportError:
from cached_property import cached_property

Expand Down
20 changes: 10 additions & 10 deletions tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_create_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
)
def test_update_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._factory_exists = Mock(return_value=True)
hook._factory_exists = Mock(return_value=True) # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should find a much better way to have mypy satisfied here.

I think we should (if that works) add spec='boolean' to the mock's constructor?

Copy link
Contributor Author

@josh-fell josh-fell Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A patch.object(...) context manager does the trick. WDYT? Mypy still complained adding a spec unfortunately.

hook.update_factory(*user_args)

hook._conn.factories.create_or_update.assert_called_with(*sdk_args)
Expand All @@ -188,7 +188,7 @@ def test_update_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
)
def test_update_factory_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._factory_exists = Mock(return_value=False)
hook._factory_exists = Mock(return_value=False) # type: ignore

with pytest.raises(AirflowException, match=r"Factory .+ does not exist"):
hook.update_factory(*user_args)
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_create_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._linked_service_exists = Mock(return_value=True)
hook._linked_service_exists = Mock(return_value=True) # type: ignore
hook.update_linked_service(*user_args)

hook._conn.linked_services.create_or_update(*sdk_args)
Expand All @@ -240,7 +240,7 @@ def test_update_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_linked_service_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._linked_service_exists = Mock(return_value=False)
hook._linked_service_exists = Mock(return_value=False) # type: ignore

with pytest.raises(AirflowException, match=r"Linked service .+ does not exist"):
hook.update_linked_service(*user_args)
Expand Down Expand Up @@ -281,7 +281,7 @@ def test_create_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._dataset_exists = Mock(return_value=True)
hook._dataset_exists = Mock(return_value=True) # type: ignore
hook.update_dataset(*user_args)

hook._conn.datasets.create_or_update.assert_called_with(*sdk_args)
Expand All @@ -292,7 +292,7 @@ def test_update_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_dataset_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._dataset_exists = Mock(return_value=False)
hook._dataset_exists = Mock(return_value=False) # type: ignore

with pytest.raises(AirflowException, match=r"Dataset .+ does not exist"):
hook.update_dataset(*user_args)
Expand Down Expand Up @@ -333,7 +333,7 @@ def test_create_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._pipeline_exists = Mock(return_value=True)
hook._pipeline_exists = Mock(return_value=True) # type: ignore
hook.update_pipeline(*user_args)

hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args)
Expand All @@ -344,7 +344,7 @@ def test_update_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_pipeline_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._pipeline_exists = Mock(return_value=False)
hook._pipeline_exists = Mock(return_value=False) # type: ignore

with pytest.raises(AirflowException, match=r"Pipeline .+ does not exist"):
hook.update_pipeline(*user_args)
Expand Down Expand Up @@ -451,7 +451,7 @@ def test_create_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._trigger_exists = Mock(return_value=True)
hook._trigger_exists = Mock(return_value=True) # type: ignore
hook.update_trigger(*user_args)

hook._conn.triggers.create_or_update.assert_called_with(*sdk_args)
Expand All @@ -462,7 +462,7 @@ def test_update_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
)
def test_update_trigger_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
hook._trigger_exists = Mock(return_value=False)
hook._trigger_exists = Mock(return_value=False) # type: ignore

with pytest.raises(AirflowException, match=r"Trigger .+ does not exist"):
hook.update_trigger(*user_args)
Expand Down