[Storage] Adjust return type of StorageStreamDownloader.readall #25174

Merged · 5 commits · Jul 15, 2022
1 change: 1 addition & 0 deletions sdk/storage/azure-storage-blob/CHANGELOG.md
@@ -5,6 +5,7 @@
### Features Added

### Bugs Fixed
- Adjusted type hints for `upload_blob` and `StorageStreamDownloader.readall`.

## 12.13.0 (2022-07-07)

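Illustrating the changelog entry above: with the adjusted hints, the downloader's element type (and therefore the type of `readall()`) follows the `encoding` argument. A minimal sketch; the connection string, container name, and blob name are placeholders:

```python
from azure.storage.blob import BlobClient

blob_client = BlobClient.from_connection_string(
    "<connection-string>", container_name="my-container", blob_name="notes.txt"
)

# No encoding: StorageStreamDownloader[bytes], so readall() is typed bytes.
raw: bytes = blob_client.download_blob().readall()

# With encoding: StorageStreamDownloader[str], so readall() is typed str.
text: str = blob_client.download_blob(encoding="utf-8").readall()
```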
@@ -8,7 +8,7 @@
from functools import partial
from io import BytesIO
from typing import (
Any, AnyStr, Dict, IO, Iterable, List, Optional, Tuple, Type, TypeVar, Union,
Any, AnyStr, Dict, IO, Iterable, List, Optional, overload, Tuple, Type, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote
@@ -733,8 +733,8 @@ def upload_blob( # pylint: disable=too-many-locals
return upload_page_blob(**options)
return upload_append_blob(**options)

def _download_blob_options(self, offset=None, length=None, **kwargs):
# type: (Optional[int], Optional[int], **Any) -> Dict[str, Any]
def _download_blob_options(self, offset=None, length=None, encoding=None, **kwargs):
# type: (Optional[int], Optional[int], Optional[str], **Any) -> Dict[str, Any]
if self.require_encryption and not self.key_encryption_key:
raise ValueError("Encryption required but no key was provided.")
if length is not None and offset is None:
@@ -768,18 +768,40 @@ def _download_blob_options(self, offset=None, length=None, **kwargs):
'lease_access_conditions': access_conditions,
'modified_access_conditions': mod_conditions,
'cpk_info': cpk_info,
'cls': kwargs.pop('cls', None) or deserialize_blob_stream,
'download_cls': kwargs.pop('cls', None) or deserialize_blob_stream,
'max_concurrency':kwargs.pop('max_concurrency', 1),
'encoding': kwargs.pop('encoding', None),
'encoding': encoding,
'timeout': kwargs.pop('timeout', None),
'name': self.blob_name,
'container': self.container_name}
options.update(kwargs)
return options

@overload
def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: str,
**kwargs) -> StorageStreamDownloader[str]:
...

@overload
def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: None = None,
**kwargs) -> StorageStreamDownloader[bytes]:
...

@distributed_trace
def download_blob(self, offset=None, length=None, **kwargs):
# type: (Optional[int], Optional[int], **Any) -> StorageStreamDownloader
def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an iterator which allows the user to iterate over the content in chunks.
@@ -867,6 +889,7 @@ def download_blob(self, offset=None, length=None, **kwargs):
options = self._download_blob_options(
offset=offset,
length=length,
encoding=encoding,
**kwargs)
return StorageStreamDownloader(**options)

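As the docstring above notes, the returned downloader is consumed with `readall()`, `readinto()`, or `chunks()`. A rough sketch of the three options (connection string, container, and blob name are placeholders; each option uses its own `download_blob()` call):

```python
from io import BytesIO
from azure.storage.blob import BlobClient

blob_client = BlobClient.from_connection_string(
    "<connection-string>", container_name="my-container", blob_name="data.bin"
)

# readall(): download the selected range (here the first 4 KiB) into memory.
head = blob_client.download_blob(offset=0, length=4096).readall()

# readinto(): stream the blob into a writable binary file-like object.
buffer = BytesIO()
blob_client.download_blob().readinto(buffer)

# chunks(): iterate over the content piece by piece.
for chunk in blob_client.download_blob().chunks():
    print(len(chunk))  # placeholder per-chunk handling
```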
@@ -7,7 +7,7 @@

import functools
from typing import ( # pylint: disable=unused-import
Any, AnyStr, Dict, List, IO, Iterable, Iterator, Optional, TypeVar, Union,
Any, AnyStr, Dict, List, IO, Iterable, Iterator, Optional, overload, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote
@@ -31,6 +31,7 @@
from ._generated.models import SignedIdentifier
from ._blob_client import BlobClient
from ._deserialize import deserialize_container_properties
from ._download import StorageStreamDownloader
from ._encryption import StorageEncryptionMixin
from ._lease import BlobLeaseClient
from ._list_blobs_helper import BlobPrefix, BlobPropertiesPaged, FilteredBlobPaged
@@ -1060,9 +1061,34 @@ def delete_blob(
timeout=timeout,
**kwargs)

@overload
def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: str,
**kwargs) -> StorageStreamDownloader[str]:
...

@overload
def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: None = None,
**kwargs) -> StorageStreamDownloader[bytes]:
...

@distributed_trace
def download_blob(self, blob, offset=None, length=None, **kwargs):
# type: (Union[str, BlobProperties], Optional[int], Optional[int], **Any) -> StorageStreamDownloader
def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an iterator which allows the user to iterate over the content in chunks.
@@ -1143,7 +1169,11 @@ def download_blob(self, blob, offset=None, length=None, **kwargs):
"""
blob_client = self.get_blob_client(blob) # type: ignore
kwargs.setdefault('merge_span', True)
return blob_client.download_blob(offset=offset, length=length, **kwargs)
return blob_client.download_blob(
offset=offset,
length=length,
encoding=encoding,
**kwargs)

def _generate_delete_blobs_subrequest_options(
self, snapshot=None,
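The container-level method above is a thin wrapper: it resolves a `BlobClient` for the named blob and forwards `offset`, `length`, and now `encoding` to it, so the same typing rules apply. A sketch with placeholder names:

```python
from azure.storage.blob import ContainerClient

container_client = ContainerClient.from_connection_string(
    "<connection-string>", container_name="my-container"
)

# Bytes by default.
payload: bytes = container_client.download_blob("report.csv").readall()

# Text once an encoding is supplied; the argument is passed through to BlobClient.
report: str = container_client.download_blob("report.csv", encoding="utf-8").readall()
```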
15 changes: 11 additions & 4 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
@@ -9,7 +9,7 @@
import time
import warnings
from io import BytesIO
from typing import Iterator, Union
from typing import Generic, Iterator, TypeVar

import requests
from azure.core.exceptions import HttpResponseError, ServiceResponseError
@@ -26,6 +26,8 @@
parse_encryption_data
)

T = TypeVar('T', bytes, str)


def process_range_and_offset(start_range, end_range, length, encryption_options, encryption_data):
start_offset, end_offset = 0, 0
@@ -281,7 +283,7 @@ def _get_chunk_data(self):
return chunk_data


class StorageStreamDownloader(object): # pylint: disable=too-many-instance-attributes
class StorageStreamDownloader(Generic[T]): # pylint: disable=too-many-instance-attributes
"""A streaming object to download from Azure Storage.

:ivar str name:
@@ -308,6 +310,7 @@ def __init__(
name=None,
container=None,
encoding=None,
download_cls=None,
**kwargs
):
self.name = name
@@ -333,6 +336,10 @@
self._response = None
self._encryption_data = None

# The cls is passed in via download_cls to avoid conflicting arg name with Generic.__new__
# but needs to be changed to cls in the request options.
self._request_options['cls'] = download_cls

if self._encryption_options.get("key") is not None or self._encryption_options.get("resolver") is not None:
self._get_encryption_data_request()

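A compact sketch of the `download_cls` plumbing above: the options dict carries the deserialization callback under `download_cls`, and `__init__` re-attaches it as `cls` in the request options, so the constructor call itself never receives a keyword literally named `cls`, which the comment says would clash with `Generic.__new__`. The class and callback below are illustrative stand-ins, not the real downloader:

```python
from typing import Any, Dict, Generic, TypeVar

T = TypeVar("T", bytes, str)

def _fake_deserializer(pipeline_response: Any, deserialized: Any, headers: Dict[str, Any]) -> Any:
    return deserialized  # stand-in for a deserialize_blob_stream-style callback

class _DownloaderSketch(Generic[T]):
    def __init__(self, download_cls: Any = None, **kwargs: Any) -> None:
        self._request_options: Dict[str, Any] = dict(kwargs)
        # Mirrors: self._request_options['cls'] = download_cls
        self._request_options["cls"] = download_cls

options = {"download_cls": _fake_deserializer, "max_concurrency": 1}
sketch = _DownloaderSketch(**options)
assert sketch._request_options["cls"] is _fake_deserializer
```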
@@ -546,11 +553,11 @@ def chunks(self):
chunk_size=self._config.max_chunk_get_size)

def readall(self):
# type: () -> Union[bytes, str]
# type: () -> T
"""Download the contents of this blob.

This operation is blocking until all data is downloaded.
:rtype: bytes or str
:rtype: T
"""
stream = BytesIO()
self.readinto(stream)
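Previously `readall()` was annotated `Union[bytes, str]` (see the removed annotation above), so strict checkers forced narrowing even when the caller knew which type to expect; with the `T` return type the result follows the downloader's type parameter. A hedged sketch, assuming a `blob_client` like the ones constructed in the earlier examples:

```python
# With the generic downloader these annotations check cleanly under mypy/pyright:
content: bytes = blob_client.download_blob().readall()
text: str = blob_client.download_blob(encoding="utf-8").readall()

# Previously the Union[bytes, str] return required narrowing, e.g.:
#     from typing import cast
#     content = cast(bytes, blob_client.download_blob().readall())
```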
@@ -8,7 +8,7 @@
import warnings
from functools import partial
from typing import ( # pylint: disable=unused-import
Any, AnyStr, Dict, IO, Iterable, List, Tuple, Optional, Union,
Any, AnyStr, Dict, IO, Iterable, List, Optional, overload, Tuple, Union,
TYPE_CHECKING
)

@@ -405,9 +405,31 @@ async def upload_blob(
return await upload_page_blob(**options)
return await upload_append_blob(**options)

@overload
async def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: str,
**kwargs) -> StorageStreamDownloader[str]:
...

@overload
async def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: None = None,
**kwargs) -> StorageStreamDownloader[bytes]:
...

@distributed_trace_async
async def download_blob(self, offset=None, length=None, **kwargs):
# type: (Optional[int], Optional[int], Any) -> StorageStreamDownloader
async def download_blob(
self, offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an async iterator which allows the user to iterate over the content in chunks.
@@ -495,6 +517,7 @@ async def download_blob(self, offset=None, length=None, **kwargs):
options = self._download_blob_options(
offset=offset,
length=length,
encoding=encoding,
**kwargs)
downloader = StorageStreamDownloader(**options)
await downloader._setup() # pylint: disable=protected-access
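The async client mirrors the sync overloads; `download_blob` is awaited, and so is `readall()` on the async downloader. A sketch with placeholder connection details:

```python
import asyncio
from azure.storage.blob.aio import BlobClient

async def main() -> None:
    async with BlobClient.from_connection_string(
        "<connection-string>", container_name="my-container", blob_name="notes.txt"
    ) as client:
        downloader = await client.download_blob(encoding="utf-8")  # StorageStreamDownloader[str]
        text: str = await downloader.readall()
        print(text)

asyncio.run(main())
```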
@@ -7,7 +7,7 @@

import functools
from typing import ( # pylint: disable=unused-import
Any, AnyStr, AsyncIterator, Dict, List, IO, Iterable, Optional, Union,
Any, AnyStr, AsyncIterator, Dict, List, IO, Iterable, Optional, overload, Union,
TYPE_CHECKING
)

@@ -30,6 +30,7 @@
from .._generated.models import SignedIdentifier
from .._container_client import ContainerClient as ContainerClientBase, _get_blob_name
from .._deserialize import deserialize_container_properties
from ._download_async import StorageStreamDownloader
from .._encryption import StorageEncryptionMixin
from .._models import ContainerProperties, BlobType, BlobProperties, FilteredBlob
from .._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions
@@ -40,7 +41,6 @@

if TYPE_CHECKING:
from datetime import datetime
from ._download_async import StorageStreamDownloader
from .._models import ( # pylint: disable=unused-import
AccessPolicy,
StandardBlobTier,
@@ -928,9 +928,34 @@ async def delete_blob(
timeout=timeout,
**kwargs)

@overload
async def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: str,
**kwargs) -> StorageStreamDownloader[str]:
...

@overload
async def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: None = None,
**kwargs) -> StorageStreamDownloader[bytes]:
...

@distributed_trace_async
async def download_blob(self, blob, offset=None, length=None, **kwargs):
# type: (Union[str, BlobProperties], Optional[int], Optional[int], Any) -> StorageStreamDownloader
async def download_blob(
self, blob: Union[str, BlobProperties],
offset: int = None,
length: int = None,
*,
encoding: Optional[str] = None,
**kwargs) -> StorageStreamDownloader:
"""Downloads a blob to the StorageStreamDownloader. The readall() method must
be used to read all the content or readinto() must be used to download the blob into
a stream. Using chunks() returns an async iterator which allows the user to iterate over the content in chunks.
@@ -1014,6 +1039,7 @@ async def download_blob(self, blob, offset=None, length=None, **kwargs):
return await blob_client.download_blob(
offset=offset,
length=length,
encoding=encoding,
**kwargs)

@distributed_trace_async
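And the container-level async variant, which forwards `encoding` to the underlying async `BlobClient` just as the sync version does (placeholder names again):

```python
import asyncio
from azure.storage.blob.aio import ContainerClient

async def main() -> None:
    async with ContainerClient.from_connection_string(
        "<connection-string>", container_name="my-container"
    ) as container_client:
        stream = await container_client.download_blob("notes.txt", encoding="utf-8")
        text: str = await stream.readall()
        print(text)

asyncio.run(main())
```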