Skip to content

Commit

Permalink
Use token-based authentication instead of the access key in the Azure…
Browse files Browse the repository at this point in the history
…FileShareService (#779)

Make the `AzureFileShareService` class use a token-based credential instead of
the access key. This PR is part of #777.

**NOTE:**
* We have to use late initialization of the `_share_client` to avoid
issues with JSON schema validation tests.
* More PRs will follow:
* Use `azcopy` on the remote VMs instead of mounting the file share
using the access key
* Remove references to `"storageAccountKey"` from configurations and
tests and document the new mechanism

---------

Co-authored-by: Brian Kroth <bpkroth@users.noreply.github.com>
  • Loading branch information
motus and bpkroth authored Jul 24, 2024
1 parent 7dce3d1 commit 73f0708
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 28 deletions.
36 changes: 22 additions & 14 deletions mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from mlos_bench.services.base_fileshare import FileShareService
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.authenticator_type import SupportsAuth
from mlos_bench.util import check_required_params

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,23 +53,30 @@ def __init__(
parent,
self.merge_methods(methods, [self.upload, self.download]),
)

check_required_params(
self.config,
{
"storageAccountName",
"storageFileShareName",
"storageAccountKey",
},
)

self._share_client = ShareClient.from_share_url(
AzureFileShareService._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self.config["storageAccountKey"],
)
self._share_client: Optional[ShareClient] = None

def _get_share_client(self) -> ShareClient:
    """
    Return the Azure file share client, creating it on first use.

    The client is built lazily instead of in the constructor and is then
    cached in `self._share_client`, so later calls reuse the same object.
    Building it requires the parent service to implement `SupportsAuth`,
    because the parent's access token is passed as the client credential.
    """
    # Fast path: the client has already been created by an earlier call.
    if self._share_client is not None:
        return self._share_client

    auth_service = self._parent
    assert auth_service is not None and isinstance(
        auth_service, SupportsAuth
    ), "Authorization service not provided. Include service-auth.jsonc?"

    share_url = self._SHARE_URL.format(
        account_name=self.config["storageAccountName"],
        fs_name=self.config["storageFileShareName"],
    )
    self._share_client = ShareClient.from_share_url(
        share_url,
        credential=auth_service.get_access_token(),
        token_intent="backup",
    )
    return self._share_client

def download(
self,
Expand All @@ -78,7 +86,7 @@ def download(
recursive: bool = True,
) -> None:
super().download(params, remote_path, local_path, recursive)
dir_client = self._share_client.get_directory_client(remote_path)
dir_client = self._get_share_client().get_directory_client(remote_path)
if dir_client.exists():
os.makedirs(local_path, exist_ok=True)
for content in dir_client.list_directories_and_files():
Expand All @@ -91,7 +99,7 @@ def download(
# Ensure parent folders exist
folder, _ = os.path.split(local_path)
os.makedirs(folder, exist_ok=True)
file_client = self._share_client.get_file_client(remote_path)
file_client = self._get_share_client().get_file_client(remote_path)
try:
data = file_client.download_file()
with open(local_path, "wb") as output_file:
Expand Down Expand Up @@ -147,7 +155,7 @@ def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[
# Ensure parent folders exist
folder, _ = os.path.split(remote_path)
self._remote_makedirs(folder)
file_client = self._share_client.get_file_client(remote_path)
file_client = self._get_share_client().get_file_client(remote_path)
with open(local_path, "rb") as file_data:
_LOG.debug("Upload file: %s -> %s", local_path, remote_path)
file_client.upload_file(file_data)
Expand All @@ -167,6 +175,6 @@ def _remote_makedirs(self, remote_path: str) -> None:
if not folder:
continue
path += folder + "/"
dir_client = self._share_client.get_directory_client(path)
dir_client = self._get_share_client().get_directory_client(path)
if not dir_client.exists():
dir_client.create_directory()
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def test_download_file(
local_folder = "some/local/folder"
remote_path = f"{remote_folder}/{filename}"
local_path = f"{local_folder}/{filename}"
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access

config: dict = {}
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object(
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client, patch.object(
mock_share_client, "get_directory_client"
) as mock_get_directory_client:

mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False))

azure_fileshare.download(config, remote_path, local_path)
Expand Down Expand Up @@ -81,8 +84,9 @@ def test_download_folder_non_recursive(
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access

config: dict = {}
with patch.object(
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_directory_client"
) as mock_get_directory_client, patch.object(
mock_share_client, "get_file_client"
Expand Down Expand Up @@ -114,15 +118,14 @@ def test_download_folder_recursive(
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access

config: dict = {}
with patch.object(
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_directory_client"
) as mock_get_directory_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]

azure_fileshare.download(config, remote_folder, local_folder, recursive=True)

mock_get_file_client.assert_has_calls(
Expand Down Expand Up @@ -157,9 +160,11 @@ def test_upload_file(
local_path = f"{local_folder}/{filename}"
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
mock_isdir.return_value = False
config: dict = {}

with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
config: dict = {}
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
azure_fileshare.upload(config, local_path, remote_path)

mock_get_file_client.assert_called_with(remote_path)
Expand Down Expand Up @@ -228,9 +233,11 @@ def test_upload_directory_non_recursive(
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}

with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
config: dict = {}
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
azure_fileshare.upload(config, local_folder, remote_folder, recursive=False)

mock_get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv")
Expand All @@ -252,9 +259,11 @@ def test_upload_directory_recursive(
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}

with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
config: dict = {}
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
azure_fileshare.upload(config, local_folder, remote_folder, recursive=True)

mock_get_file_client.assert_has_calls(
Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> A


@pytest.fixture
def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService:
def azure_fileshare(azure_auth_service: AzureAuthService) -> AzureFileShareService:
"""Creates a dummy AzureFileShareService for tests that require it."""
with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"):
return AzureFileShareService(
Expand All @@ -112,5 +112,5 @@ def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> Azu
"storageAccountKey": "TEST_ACCOUNT_KEY",
},
global_config={},
parent=config_persistence_service,
parent=azure_auth_service,
)

0 comments on commit 73f0708

Please sign in to comment.