diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index d80ea862c9..5ff1b638a3 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -13,6 +13,7 @@ from mlos_bench.services.base_fileshare import FileShareService from mlos_bench.services.base_service import Service +from mlos_bench.services.types.authenticator_type import SupportsAuth from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) @@ -52,23 +53,30 @@ def __init__( parent, self.merge_methods(methods, [self.upload, self.download]), ) - check_required_params( self.config, { "storageAccountName", "storageFileShareName", - "storageAccountKey", }, ) - - self._share_client = ShareClient.from_share_url( - AzureFileShareService._SHARE_URL.format( - account_name=self.config["storageAccountName"], - fs_name=self.config["storageFileShareName"], - ), - credential=self.config["storageAccountKey"], - ) + self._share_client: Optional[ShareClient] = None + + def _get_share_client(self) -> ShareClient: + """Get the Azure file share client object.""" + if self._share_client is None: + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" + self._share_client = ShareClient.from_share_url( + self._SHARE_URL.format( + account_name=self.config["storageAccountName"], + fs_name=self.config["storageFileShareName"], + ), + credential=self._parent.get_access_token(), + token_intent="backup", + ) + return self._share_client def download( self, @@ -78,7 +86,7 @@ def download( recursive: bool = True, ) -> None: super().download(params, remote_path, local_path, recursive) - dir_client = self._share_client.get_directory_client(remote_path) + dir_client = self._get_share_client().get_directory_client(remote_path) if dir_client.exists(): os.makedirs(local_path, exist_ok=True) for content in dir_client.list_directories_and_files(): @@ -91,7 +99,7 @@ def download( # Ensure parent folders exist folder, _ = os.path.split(local_path) os.makedirs(folder, exist_ok=True) - file_client = self._share_client.get_file_client(remote_path) + file_client = self._get_share_client().get_file_client(remote_path) try: data = file_client.download_file() with open(local_path, "wb") as output_file: @@ -147,7 +155,7 @@ def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[ # Ensure parent folders exist folder, _ = os.path.split(remote_path) self._remote_makedirs(folder) - file_client = self._share_client.get_file_client(remote_path) + file_client = self._get_share_client().get_file_client(remote_path) with open(local_path, "rb") as file_data: _LOG.debug("Upload file: %s -> %s", local_path, remote_path) file_client.upload_file(file_data) @@ -167,6 +175,6 @@ def _remote_makedirs(self, remote_path: str) -> None: if not folder: continue path += folder + "/" - dir_client = self._share_client.get_directory_client(path) + dir_client = self._get_share_client().get_directory_client(path) if not dir_client.exists(): dir_client.create_directory() diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index 79090a2f5f..d2aecc2275 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -26,11 +26,14 @@ def test_download_file( local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object( + with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client, patch.object( mock_share_client, "get_directory_client" ) as mock_get_directory_client: + mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False)) azure_fileshare.download(config, remote_path, local_path) @@ -81,8 +84,9 @@ def test_download_folder_non_recursive( local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + config: dict = {} - with patch.object( + with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object( mock_share_client, "get_directory_client" ) as mock_get_directory_client, patch.object( mock_share_client, "get_file_client" @@ -114,15 +118,14 @@ def test_download_folder_recursive( remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + config: dict = {} - with patch.object( + with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object( mock_share_client, "get_directory_client" ) as mock_get_directory_client, patch.object( mock_share_client, "get_file_client" ) as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] - azure_fileshare.download(config, remote_folder, local_folder, recursive=True) mock_get_file_client.assert_has_calls( @@ -157,9 +160,11 @@ def test_upload_file( local_path = f"{local_folder}/{filename}" mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access mock_isdir.return_value = False - config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + config: dict = {} + with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: azure_fileshare.upload(config, local_path, remote_path) mock_get_file_client.assert_called_with(remote_path) @@ -228,9 +233,11 @@ def test_upload_directory_non_recursive( mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access - config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + config: dict = {} + with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: azure_fileshare.upload(config, local_folder, remote_folder, recursive=False) mock_get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv") @@ -252,9 +259,11 @@ def test_upload_directory_recursive( mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access - config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + config: dict = {} + with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: azure_fileshare.upload(config, local_folder, remote_folder, recursive=True) mock_get_file_client.assert_has_calls( diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index ad7bae26ee..37d554c897 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -102,7 +102,7 @@ def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> A @pytest.fixture -def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService: +def azure_fileshare(azure_auth_service: AzureAuthService) -> AzureFileShareService: """Creates a dummy AzureFileShareService for tests that require it.""" with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"): return AzureFileShareService( @@ -112,5 +112,5 @@ def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> Azu "storageAccountKey": "TEST_ACCOUNT_KEY", }, global_config={}, - parent=config_persistence_service, + parent=azure_auth_service, )