Skip to content

Commit

Permalink
Fix typing and mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
Eu Jing Chua committed Jul 29, 2024
1 parent 8993318 commit 4a964ee
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 37 deletions.
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from pytz import UTC

from mlos_bench.services.base_service import Service
from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth
from mlos_bench.services.types.authenticator_type import SupportsAuth
from mlos_bench.util import check_required_params

_LOG = logging.getLogger(__name__)


class AzureAuthService(Service, SupportsAzureAuth):
class AzureAuthService(Service, SupportsAuth[azure_cred.TokenCredential]):
"""Helper methods to get access to Azure services."""

_REQ_INTERVAL = 300 # = 5 min
Expand Down
16 changes: 11 additions & 5 deletions mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import os
from typing import Any, Callable, Dict, List, Optional, Set, Union

import azure.core.credentials as azure_cred
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.fileshare import ShareClient

from mlos_bench.services.base_fileshare import FileShareService
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth
from mlos_bench.services.types.authenticator_type import SupportsAuth
from mlos_bench.util import check_required_params

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,20 +61,25 @@ def __init__(
"storageFileShareName",
},
)
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
self._auth_service: SupportsAuth[azure_cred.TokenCredential] = self._parent
self._share_client: Optional[ShareClient] = None

def _get_share_client(self) -> ShareClient:
"""Get the Azure file share client object."""
if self._share_client is None:
assert self._parent is not None and isinstance(
self._parent, SupportsAzureAuth
), "Authorization service not provided. Include service-auth.jsonc?"
credential = self._auth_service.get_credential()
assert isinstance(
credential, azure_cred.TokenCredential
), f"Expected a TokenCredential, but got {type(credential)} instead."
self._share_client = ShareClient.from_share_url(
self._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self._parent.get_credential(),
credential=credential,
token_intent="backup",
)
return self._share_client
Expand Down
2 changes: 0 additions & 2 deletions mlos_bench/mlos_bench/services/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

from mlos_bench.services.types.authenticator_type import SupportsAuth
from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning
Expand All @@ -20,7 +19,6 @@

__all__ = [
"SupportsAuth",
"SupportsAzureAuth",
"SupportsConfigLoading",
"SupportsFileShareOps",
"SupportsHostProvisioning",
Expand Down
16 changes: 14 additions & 2 deletions mlos_bench/mlos_bench/services/types/authenticator_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
#
"""Protocol interface for authentication for the cloud services."""

from typing import Protocol, runtime_checkable
from typing import Protocol, TypeVar, runtime_checkable

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class SupportsAuth(Protocol):
class SupportsAuth(Protocol[T_co]):
"""Protocol interface for authentication for the cloud services."""

def get_access_token(self) -> str:
Expand All @@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict:
access_header : dict
HTTP header containing the access token.
"""

def get_credential(self) -> T_co:
"""
Get the credential object for cloud services.
Returns
-------
credential : T
Cloud-specific credential object.
"""
24 changes: 0 additions & 24 deletions mlos_bench/mlos_bench/services/types/azure_authenticator_type.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,28 @@ def test_load_service_config_examples(
config_path: str,
) -> None:
"""Tests loading a config example."""
parent: Service = config_loader_service
config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE)
# Add other services that require a SupportsAuth parent service as necessary.
requires_auth_service_parent = {
"AzureFileShareService",
}
config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1]
if config_class_name in requires_auth_service_parent:
# AzureFileShareService requires an auth service to be loaded as well.
auth_service_config = config_loader_service.load_config(
"services/remote/mock/mock_auth_service.jsonc",
ConfigSchema.SERVICE,
)
auth_service = config_loader_service.build_service(
config=auth_service_config,
parent=config_loader_service,
)
parent = auth_service
# Make an instance of the class based on the config.
service_inst = config_loader_service.build_service(
config=config,
parent=config_loader_service,
parent=parent,
)
assert service_inst is not None
assert isinstance(service_inst, Service)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_LOG = logging.getLogger(__name__)


class MockAuthService(Service, SupportsAuth):
class MockAuthService(Service, SupportsAuth[str]):
"""A collection Service functions for mocking authentication ops."""

def __init__(
Expand All @@ -32,6 +32,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand All @@ -41,3 +42,6 @@ def get_access_token(self) -> str:

def get_auth_headers(self) -> dict:
return {"Authorization": "Bearer " + self.get_access_token()}

def get_credential(self) -> str:
return "MOCK CREDENTIAL"

0 comments on commit 4a964ee

Please sign in to comment.