Skip to content

Commit

Permalink
[Storage] az storage blob copy start/start-batch: Fix `--auth-mode …
Browse files Browse the repository at this point in the history
…login` (#29964)

* fix `az storage blob copy start` `--auth-mode login`

* file sas does not allow user-delegation-key, style

* rerun all copy start and start-batch tests

* lint
  • Loading branch information
calvinhzy authored Oct 9, 2024
1 parent 6bbf23a commit e7f2024
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 68 deletions.
54 changes: 41 additions & 13 deletions src/azure-cli/azure/cli/command_modules/storage/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def process_blob_source_uri(cmd, namespace):
if not sas:
prefix = cmd.command_kwargs['resource_type'].value[0]
if is_storagev2(prefix):
sas = create_short_lived_blob_sas_v2(cmd, source_account_name, source_account_key, container,
blob)
sas = create_short_lived_blob_sas_v2(cmd, source_account_name, container,
blob, account_key=source_account_key)
else:
sas = create_short_lived_blob_sas(cmd, source_account_name, source_account_key, container, blob)
query_params = []
Expand Down Expand Up @@ -409,8 +409,8 @@ def validate_source_uri(cmd, namespace): # pylint: disable=too-many-statements
dir_name, file_name)
elif valid_blob_source and (ns.get('share_name', None) or not same_account):
if is_storagev2(prefix):
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, source_account_key, container,
blob)
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, container,
blob, account_key=source_account_key)
else:
source_sas = create_short_lived_blob_sas(cmd, source_account_name, source_account_key, container, blob)

Expand All @@ -435,7 +435,8 @@ def validate_source_uri(cmd, namespace): # pylint: disable=too-many-statements


def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements, too-many-locals
from .util import create_short_lived_blob_sas, create_short_lived_blob_sas_v2, create_short_lived_file_sas
from .util import create_short_lived_blob_sas, create_short_lived_blob_sas_v2, create_short_lived_file_sas, \
create_short_lived_file_sas_v2
from azure.cli.core.azclierror import InvalidArgumentValueError, RequiredArgumentMissingError, \
MutuallyExclusiveArgumentError
usage_string = \
Expand Down Expand Up @@ -463,6 +464,8 @@ def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements,
source_account_name = ns.pop('source_account_name', None)
source_account_key = ns.pop('source_account_key', None)
source_sas = ns.pop('source_sas', None)
token_credential = ns.get('token_credential')
is_oauth = token_credential is not None

# source in the form of an uri
uri = ns.get('source_url', None)
Expand Down Expand Up @@ -499,7 +502,7 @@ def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements,
# determine if the copy will happen in the same storage account
same_account = False

if not source_account_key and not source_sas:
if not source_account_key and not source_sas and not is_oauth:
if source_account_name == ns.get('account_name', None):
same_account = True
source_account_key = ns.get('account_key', None)
Expand All @@ -511,20 +514,41 @@ def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements,
except ValueError:
raise RequiredArgumentMissingError('Source storage account {} not found.'.format(source_account_name))

# if oauth, use user delegation key to generate sas
source_user_delegation_key = None
if is_oauth:
client_kwargs = {'account_name': source_account_name,
'token_credential': token_credential}
if valid_blob_source:
client = cf_blob_service(cmd.cli_ctx, client_kwargs)

from datetime import datetime, timedelta
start = datetime.utcnow()
expiry = datetime.utcnow() + timedelta(days=1)
source_user_delegation_key = client.get_user_delegation_key(start, expiry)

# Both source account name and either key or sas (or both) are now available
if not source_sas:
prefix = cmd.command_kwargs['resource_type'].value[0]
# generate a sas token even in the same account when the source and destination are not the same kind.
if valid_file_source and (ns.get('container_name', None) or not same_account):
dir_name, file_name = os.path.split(path) if path else (None, '')
source_sas = create_short_lived_file_sas(cmd, source_account_name, source_account_key, share,
dir_name, file_name)
if dir_name == '':
dir_name = None
if is_storagev2(prefix):
source_sas = create_short_lived_file_sas_v2(cmd, source_account_name, source_account_key, share,
dir_name, file_name)
else:
source_sas = create_short_lived_file_sas(cmd, source_account_name, source_account_key, share,
dir_name, file_name)
elif valid_blob_source and (ns.get('share_name', None) or not same_account):
prefix = cmd.command_kwargs['resource_type'].value[0]
# is_storagev2() is used to distinguish if the command is in track2 SDK
# If yes, we will use get_login_credentials() as token credential
if is_storagev2(prefix):
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, source_account_key, container,
blob)
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, container, blob,
account_key=source_account_key,
user_delegation_key=source_user_delegation_key)
else:
source_sas = create_short_lived_blob_sas(cmd, source_account_name, source_account_key, container, blob)

Expand Down Expand Up @@ -1069,6 +1093,8 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):
source_sas = ns.get('source_sas', None)
source_container = ns.get('source_container', None)
source_share = ns.get('source_share', None)
token_credential = ns.get('token_credential')
is_oauth = token_credential is not None

if source_uri and source_account:
raise ValueError(usage_string)
Expand All @@ -1090,13 +1116,13 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):

source_account, source_key, source_sas = ns['account_name'], ns['account_key'], ns['sas_token']

if source_account:
if source_account and not is_oauth:
if not (source_key or source_sas):
# when neither storage account key nor SAS is given, try to fetch the key in the current
# subscription
source_key = _query_account_key(cmd.cli_ctx, source_account)

elif source_uri:
elif source_uri and not is_oauth:
if source_key or source_container or source_share:
raise ValueError(usage_string)

Expand Down Expand Up @@ -1125,7 +1151,7 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):
ns['source_container'] = source_container
ns['source_share'] = source_share
# get sas token for source
if not source_sas:
if not source_sas and not is_oauth:
from .util import create_short_lived_container_sas_track2, create_short_lived_share_sas_track2
if source_container:
source_sas = create_short_lived_container_sas_track2(cmd, account_name=source_account,
Expand All @@ -1139,6 +1165,8 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):
client_kwargs = {'account_name': ns['source_account_name'],
'account_key': ns['source_account_key'],
'sas_token': ns['source_sas']}
if is_oauth:
client_kwargs.update({'token_credential': token_credential})
if source_container:
ns['source_client'] = cf_blob_service(cmd.cli_ctx, client_kwargs)
if source_share:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,14 @@ def create_blob_url(client, container_name, blob_name, snapshot, protocol='https
def _copy_blob_to_blob_container(cmd, blob_service, source_blob_service, destination_container, destination_path,
source_container, source_blob_name, source_sas, **kwargs):
t_blob_client = cmd.get_models('_blob_client#BlobClient')
# generate sas for oauth copy source
if not source_sas:
from ..util import create_short_lived_blob_sas_v2
start = datetime.utcnow()
expiry = datetime.utcnow() + timedelta(hours=1)
source_user_delegation_key = source_blob_service.get_user_delegation_key(start, expiry)
source_sas = create_short_lived_blob_sas_v2(cmd, source_blob_service.account_name, source_container,
source_blob_name, user_delegation_key=source_user_delegation_key)
source_client = t_blob_client(account_url=source_blob_service.url, container_name=source_container,
blob_name=source_blob_name, credential=source_sas)
source_blob_url = source_client.url
Expand All @@ -931,7 +939,10 @@ def _copy_file_to_blob_container(blob_service, source_file_service, destination_
source_share, source_sas, source_file_dir, source_file_name):
t_share_client = source_file_service.get_share_client(source_share)
t_file_client = t_share_client.get_file_client(os.path.join(source_file_dir, source_file_name))
source_file_url = '{}?{}'.format(t_file_client.url, source_sas)
if '?' not in t_file_client.url:
source_file_url = '{}?{}'.format(t_file_client.url, source_sas)
else:
source_file_url = t_file_client.url

source_path = os.path.join(source_file_dir, source_file_name) if source_file_dir else source_file_name
destination_blob_name = normalize_blob_file_path(destination_path, source_path)
Expand Down
Loading

0 comments on commit e7f2024

Please sign in to comment.