Skip to content

Commit

Permalink
[Storage] Adjust return type of StorageStreamDownloader.readall (#2…
Browse files Browse the repository at this point in the history
  • Loading branch information
jalauzon-msft authored Jul 15, 2022
1 parent e24cf4e commit 8c24354
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 41 deletions.
1 change: 1 addition & 0 deletions sdk/storage/azure-storage-blob/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Features Added

### Bugs Fixed
- Adjusted type hints for `upload_blob` and `StorageStreamDownloader.readall`.

## 12.13.0 (2022-07-07)

Expand Down
37 changes: 30 additions & 7 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import partial
from io import BytesIO
from typing import (
Any, AnyStr, Dict, IO, Iterable, List, Optional, Tuple, Type, TypeVar, Union,
Any, AnyStr, Dict, IO, Iterable, List, Optional, overload, Tuple, Type, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote
Expand Down Expand Up @@ -733,8 +733,8 @@ def upload_blob( # pylint: disable=too-many-locals
return upload_page_blob(**options)
return upload_append_blob(**options)

def _download_blob_options(self, offset=None, length=None, **kwargs):
# type: (Optional[int], Optional[int], **Any) -> Dict[str, Any]
def _download_blob_options(self, offset=None, length=None, encoding=None, **kwargs):
# type: (Optional[int], Optional[int], Optional[str], **Any) -> Dict[str, Any]
if self.require_encryption and not self.key_encryption_key:
raise ValueError("Encryption required but no key was provided.")
if length is not None and offset is None:
Expand Down Expand Up @@ -768,18 +768,40 @@ def _download_blob_options(self, offset=None, length=None, **kwargs):
'lease_access_conditions': access_conditions,
'modified_access_conditions': mod_conditions,
'cpk_info': cpk_info,
'cls': kwargs.pop('cls', None) or deserialize_blob_stream,
'download_cls': kwargs.pop('cls', None) or deserialize_blob_stream,
'max_concurrency':kwargs.pop('max_concurrency', 1),
'encoding': kwargs.pop('encoding', None),
'encoding': encoding,
'timeout': kwargs.pop('timeout', None),
'name': self.blob_name,
'container': self.container_name}
options.update(kwargs)
return options

@overload
def download_blob(
        self, offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: str,
        **kwargs) -> StorageStreamDownloader[str]:
    # Overload: when a text encoding is supplied, the downloader's readall()
    # decodes the blob content and yields str.
    ...

@overload
def download_blob(
        self, offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: None = None,
        **kwargs) -> StorageStreamDownloader[bytes]:
    # Overload: with no encoding (the default), readall() yields raw bytes.
    ...

@distributed_trace
def download_blob(self, offset=None, length=None, **kwargs):
# type: (Optional[int], Optional[int], **Any) -> StorageStreamDownloader
def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an iterator which allows the user to iterate over the content in chunks.
Expand Down Expand Up @@ -867,6 +889,7 @@ def download_blob(self, offset=None, length=None, **kwargs):
options = self._download_blob_options(
offset=offset,
length=length,
encoding=encoding,
**kwargs)
return StorageStreamDownloader(**options)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import functools
from typing import ( # pylint: disable=unused-import
Any, AnyStr, Dict, List, IO, Iterable, Iterator, Optional, TypeVar, Union,
Any, AnyStr, Dict, List, IO, Iterable, Iterator, Optional, overload, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote
Expand All @@ -31,6 +31,7 @@
from ._generated.models import SignedIdentifier
from ._blob_client import BlobClient
from ._deserialize import deserialize_container_properties
from ._download import StorageStreamDownloader
from ._encryption import StorageEncryptionMixin
from ._lease import BlobLeaseClient
from ._list_blobs_helper import BlobPrefix, BlobPropertiesPaged, FilteredBlobPaged
Expand Down Expand Up @@ -1060,9 +1061,34 @@ def delete_blob(
timeout=timeout,
**kwargs)

@overload
def download_blob(
        self, blob: Union[str, BlobProperties],
        offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: str,
        **kwargs) -> StorageStreamDownloader[str]:
    # Overload: when a text encoding is supplied, the downloader's readall()
    # decodes the blob content and yields str.
    ...

@overload
def download_blob(
        self, blob: Union[str, BlobProperties],
        offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: None = None,
        **kwargs) -> StorageStreamDownloader[bytes]:
    # Overload: with no encoding (the default), readall() yields raw bytes.
    ...

@distributed_trace
def download_blob(self, blob, offset=None, length=None, **kwargs):
# type: (Union[str, BlobProperties], Optional[int], Optional[int], **Any) -> StorageStreamDownloader
def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an iterator which allows the user to iterate over the content in chunks.
Expand Down Expand Up @@ -1143,7 +1169,11 @@ def download_blob(self, blob, offset=None, length=None, **kwargs):
"""
blob_client = self.get_blob_client(blob) # type: ignore
kwargs.setdefault('merge_span', True)
return blob_client.download_blob(offset=offset, length=length, **kwargs)
return blob_client.download_blob(
offset=offset,
length=length,
encoding=encoding,
**kwargs)

def _generate_delete_blobs_subrequest_options(
self, snapshot=None,
Expand Down
15 changes: 11 additions & 4 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import warnings
from io import BytesIO
from typing import Iterator, Union
from typing import Generic, Iterator, TypeVar

import requests
from azure.core.exceptions import HttpResponseError, ServiceResponseError
Expand All @@ -26,6 +26,8 @@
parse_encryption_data
)

T = TypeVar('T', bytes, str)


def process_range_and_offset(start_range, end_range, length, encryption_options, encryption_data):
start_offset, end_offset = 0, 0
Expand Down Expand Up @@ -281,7 +283,7 @@ def _get_chunk_data(self):
return chunk_data


class StorageStreamDownloader(object): # pylint: disable=too-many-instance-attributes
class StorageStreamDownloader(Generic[T]): # pylint: disable=too-many-instance-attributes
"""A streaming object to download from Azure Storage.
:ivar str name:
Expand All @@ -308,6 +310,7 @@ def __init__(
name=None,
container=None,
encoding=None,
download_cls=None,
**kwargs
):
self.name = name
Expand All @@ -333,6 +336,10 @@ def __init__(
self._response = None
self._encryption_data = None

# The cls is passed in via download_cls to avoid conflicting arg name with Generic.__new__
# but needs to be changed to cls in the request options.
self._request_options['cls'] = download_cls

if self._encryption_options.get("key") is not None or self._encryption_options.get("resolver") is not None:
self._get_encryption_data_request()

Expand Down Expand Up @@ -546,11 +553,11 @@ def chunks(self):
chunk_size=self._config.max_chunk_get_size)

def readall(self):
# type: () -> Union[bytes, str]
# type: () -> T
"""Download the contents of this blob.
This operation is blocking until all data is downloaded.
:rtype: bytes or str
:rtype: T
"""
stream = BytesIO()
self.readinto(stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from functools import partial
from typing import ( # pylint: disable=unused-import
Any, AnyStr, Dict, IO, Iterable, List, Tuple, Optional, Union,
Any, AnyStr, Dict, IO, Iterable, List, Optional, overload, Tuple, Union,
TYPE_CHECKING
)

Expand Down Expand Up @@ -405,9 +405,31 @@ async def upload_blob(
return await upload_page_blob(**options)
return await upload_append_blob(**options)

@overload
async def download_blob(
        self, offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: str,
        **kwargs) -> StorageStreamDownloader[str]:
    # Overload: when a text encoding is supplied, the downloader's readall()
    # decodes the blob content and yields str.
    ...

@overload
async def download_blob(
        self, offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: None = None,
        **kwargs) -> StorageStreamDownloader[bytes]:
    # Overload: with no encoding (the default), readall() yields raw bytes.
    ...

@distributed_trace_async
async def download_blob(self, offset=None, length=None, **kwargs):
# type: (Optional[int], Optional[int], Any) -> StorageStreamDownloader
async def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an async iterator which allows the user to iterate over the content in chunks.
Expand Down Expand Up @@ -495,6 +517,7 @@ async def download_blob(self, offset=None, length=None, **kwargs):
options = self._download_blob_options(
offset=offset,
length=length,
encoding=encoding,
**kwargs)
downloader = StorageStreamDownloader(**options)
await downloader._setup() # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import functools
from typing import ( # pylint: disable=unused-import
Any, AnyStr, AsyncIterator, Dict, List, IO, Iterable, Optional, Union,
Any, AnyStr, AsyncIterator, Dict, List, IO, Iterable, Optional, overload, Union,
TYPE_CHECKING
)

Expand All @@ -30,6 +30,7 @@
from .._generated.models import SignedIdentifier
from .._container_client import ContainerClient as ContainerClientBase, _get_blob_name
from .._deserialize import deserialize_container_properties
from ._download_async import StorageStreamDownloader
from .._encryption import StorageEncryptionMixin
from .._models import ContainerProperties, BlobType, BlobProperties, FilteredBlob
from .._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions
Expand All @@ -40,7 +41,6 @@

if TYPE_CHECKING:
from datetime import datetime
from ._download_async import StorageStreamDownloader
from .._models import ( # pylint: disable=unused-import
AccessPolicy,
StandardBlobTier,
Expand Down Expand Up @@ -928,9 +928,34 @@ async def delete_blob(
timeout=timeout,
**kwargs)

@overload
async def download_blob(
        self, blob: Union[str, BlobProperties],
        offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: str,
        **kwargs) -> StorageStreamDownloader[str]:
    # Overload: when a text encoding is supplied, the downloader's readall()
    # decodes the blob content and yields str.
    ...

@overload
async def download_blob(
        self, blob: Union[str, BlobProperties],
        offset: Optional[int] = None,
        length: Optional[int] = None,
        *,
        encoding: None = None,
        **kwargs) -> StorageStreamDownloader[bytes]:
    # Overload: with no encoding (the default), readall() yields raw bytes.
    ...

@distributed_trace_async
async def download_blob(self, blob, offset=None, length=None, **kwargs):
# type: (Union[str, BlobProperties], Optional[int], Optional[int], Any) -> StorageStreamDownloader
async def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an async iterator which allows the user to iterate over the content in chunks.
Expand Down Expand Up @@ -1014,6 +1039,7 @@ async def download_blob(self, blob, offset=None, length=None, **kwargs):
return await blob_client.download_blob(
offset=offset,
length=length,
encoding=encoding,
**kwargs)

@distributed_trace_async
Expand Down
Loading

0 comments on commit 8c24354

Please sign in to comment.