Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LRO continuation_token #10801

Merged
merged 24 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4117b62
LRO continuation_token
lmazuel Apr 10, 2020
bfa829f
from cont token is a clsmethod
lmazuel Apr 10, 2020
9b04974
Add async ABC for cont token
lmazuel Apr 10, 2020
6ed2b7e
Pickle and azure-core pipeline
lmazuel Apr 10, 2020
8aeb891
Make a aiohttp response pickable, but loosing the internal response
lmazuel Apr 13, 2020
2d61d9c
Add AsyncLROPoller
lmazuel Apr 13, 2020
f08c5b1
Merge remote-tracking branch 'origin/master' into lro_continuation_token
lmazuel Apr 21, 2020
ccb627d
mypy
lmazuel Apr 21, 2020
568d68b
mpylint
lmazuel Apr 21, 2020
71f1b8b
Continuation token are optional abstract methods
lmazuel Apr 21, 2020
c9e8de3
Merge remote-tracking branch 'origin/master' into lro_continuation_token
lmazuel Apr 28, 2020
e66c797
Add async status
lmazuel Apr 28, 2020
cc859ea
mypy
lmazuel Apr 28, 2020
98ed170
Merge remote-tracking branch 'origin/master' into lro_continuation_token
lmazuel Apr 29, 2020
52ba1d8
base64 the continuation token to be a string and not bytes
lmazuel Apr 29, 2020
500b2be
Merge remote-tracking branch 'origin/master' into lro_continuation_token
lmazuel May 1, 2020
c4420ec
Typo
lmazuel May 1, 2020
d91cd26
Tests and new AsyncPoller
lmazuel May 1, 2020
02eaf58
Merge remote-tracking branch 'origin/master' into lro_continuation_token
lmazuel May 19, 2020
d6014bd
Fix mypy
lmazuel May 19, 2020
ffb34cb
Fix tests for Python 2.7
lmazuel May 20, 2020
1c76b88
More tests
lmazuel May 20, 2020
fa71ece
Add more tests, including asyncio_ensure_future wrapper
lmazuel May 22, 2020
be43f6c
Merge remote-tracking branch 'origin/master' into lro_continuation_token
lmazuel May 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion sdk/core/azure-core/azure/core/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,21 @@ def __init__(self, transport, **kwargs): # pylint: disable=super-init-not-calle
self.options = kwargs
self._protected = ["transport", "options"]

def __getstate__(self):
state = self.__dict__.copy()
# Remove the unpicklable entries.
del state['transport']
return state

def __setstate__(self, state):
self.__dict__.update(state)
# Re-create the unpickable entries
self.transport = None

def __setitem__(self, key, item):
if key in self._protected:
# If reloaded from pickle, _protected might not be here until restored by pickle
# this explains the hasattr test
if hasattr(self, '_protected') and key in self._protected:
raise ValueError("Context value {} cannot be overwritten.".format(key))
return super(PipelineContext, self).__setitem__(key, item)

Expand Down
11 changes: 11 additions & 0 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,14 @@ def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
:type pipeline: azure.core.pipeline
"""
return AioHttpStreamDownloadGenerator(pipeline, self)

def __getstate__(self):
# Be sure body is loaded in memory, otherwise not pickable and let it throw
self.body()

state = self.__dict__.copy()
# Remove the unpicklable entries.
state['internal_response'] = None # aiohttp response are not pickable (see headers comments)
from multidict import MultiDict # I know it's importable since aiohttp is loaded
state['headers'] = MultiDict(self.headers) # MultiDictProxy is not pickable
return state
annatisch marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions sdk/core/azure-core/azure/core/polling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@
#pylint: disable=unused-import
if sys.version_info >= (3, 5, 2):
# Not executed on old Python, no syntax error
from ._async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller
__all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller']
from ._async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller, AsyncLROPoller
__all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller', 'AsyncLROPoller']
120 changes: 110 additions & 10 deletions sdk/core/azure-core/azure/core/polling/_async_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Generic, TypeVar, Any
from collections.abc import Awaitable
from typing import Callable, Any, Tuple, Generic, TypeVar, Generator

from ._poller import NoPolling as _NoPolling


Expand All @@ -48,6 +50,21 @@ def finished(self) -> bool:
def resource(self) -> PollingReturnType:
raise NotImplementedError("This method needs to be implemented")

def get_continuation_token(self) -> str:
raise TypeError(
"Polling method '{}' doesn't support get_continuation_token".format(
self.__class__.__name__
)
)

@classmethod
def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any, Any, Callable]:
raise TypeError(
"Polling method '{}' doesn't support from_continuation_token".format(
cls.__name__
)
)


class AsyncNoPolling(_NoPolling):
"""An empty async poller that returns the deserialized initial response.
Expand All @@ -61,6 +78,9 @@ async def run(self):
async def async_poller(client, initial_response, deserialization_callback, polling_method):
"""Async Poller for long running operations.

.. deprecated:: 1.5.0
Use :class:`AsyncLROPoller` instead.

:param client: A pipeline service client.
:type client: ~azure.core.PipelineClient
:param initial_response: The initial call response
Expand All @@ -71,15 +91,95 @@ async def async_poller(client, initial_response, deserialization_callback, polli
:param polling_method: The polling strategy to adopt
:type polling_method: ~azure.core.polling.PollingMethod
"""
poller = AsyncLROPoller(client, initial_response, deserialization_callback, polling_method)
return await poller

# This implicit test avoids bringing in an explicit dependency on Model directly
try:
deserialization_callback = deserialization_callback.deserialize
except AttributeError:
pass

# Might raise a CloudError
polling_method.initialize(client, initial_response, deserialization_callback)
class AsyncLROPoller(Awaitable, Generic[PollingReturnType]):
"""Async poller for long running operations.

:param client: A pipeline service client
:type client: ~azure.core.PipelineClient
:param initial_response: The initial call response
:type initial_response:
~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse
:param deserialization_callback: A callback that takes a Response and return a deserialized object.
If a subclass of Model is given, this passes "deserialize" as callback.
:type deserialization_callback: callable or msrest.serialization.Model
:param polling_method: The polling strategy to adopt
:type polling_method: ~azure.core.polling.AsyncPollingMethod
"""

def __init__(
self,
client: Any,
initial_response: Any,
deserialization_callback: Callable,
polling_method: AsyncPollingMethod[PollingReturnType]
):
self._polling_method = polling_method
self._done = False

# This implicit test avoids bringing in an explicit dependency on Model directly
try:
deserialization_callback = deserialization_callback.deserialize # type: ignore
except AttributeError:
pass

self._polling_method.initialize(client, initial_response, deserialization_callback)

def continuation_token(self) -> str:
"""Return a continuation token that allows to restart the poller later.

:returns: An opaque continuation token
:rtype: str
"""
return self._polling_method.get_continuation_token()

@classmethod
def from_continuation_token(
cls,
polling_method: AsyncPollingMethod[PollingReturnType],
continuation_token: str,
**kwargs
) -> "AsyncLROPoller[PollingReturnType]":
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
return cls(client, initial_response, deserialization_callback, polling_method)

await polling_method.run()
return polling_method.resource()
def status(self) -> str:
"""Returns the current status string.

:returns: The current status string
:rtype: str
"""
return self._polling_method.status()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary for this PR - but it would be good if we could open up the contract regarding the expected return type status in future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've started creating an azure-core-typing package multiple times that would include all the protocols that we have (LRO, item paged, page paged etc.). I will take another stab at it as part of the next round of azure client library python design guideline updates.


async def result(self) -> PollingReturnType:
"""Return the result of the long running operation.

:returns: The deserialized resource of the long running operation, if one is available.
:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
"""
await self.wait()
return self._polling_method.resource()

def __await__(self) -> Generator[Any, None, PollingReturnType]:
return self.result().__await__()

async def wait(self) -> None:
"""Wait on the long running operation.

:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
"""
await self._polling_method.run()
self._done = True

def done(self) -> bool:
"""Check status of the long running operation.

:returns: 'True' if the process has completed, else 'False'.
:rtype: bool
"""
return self._done
67 changes: 57 additions & 10 deletions sdk/core/azure-core/azure/core/polling/_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,22 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import base64
import threading
import uuid
try:
from urlparse import urlparse # type: ignore # pylint: disable=unused-import
except ImportError:
from urllib.parse import urlparse

from typing import Any, Callable, Union, List, Optional, TypeVar, Generic, TYPE_CHECKING
from typing import Any, Callable, Union, List, Optional, Tuple, TypeVar, Generic
from azure.core.pipeline.transport._base import HttpResponse
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.common import with_current_context

if TYPE_CHECKING:
import requests
from msrest.serialization import Model # pylint: disable=unused-import
DeserializationCallbackType = Union[Model, Callable[[requests.Response], Model]]
PollingReturnType = TypeVar("PollingReturnType")


class PollingMethod(Generic[PollingReturnType]):
"""ABC class for polling method.
"""
Expand All @@ -64,6 +62,24 @@ def resource(self):
# type: () -> PollingReturnType
raise NotImplementedError("This method needs to be implemented")

def get_continuation_token(self):
# type() -> str
raise TypeError(
"Polling method '{}' doesn't support get_continuation_token".format(
self.__class__.__name__
)
)

@classmethod
def from_continuation_token(cls, continuation_token, **kwargs):
# type(str, Any) -> Tuple[Any, Any, Callable]
raise TypeError(
"Polling method '{}' doesn't support from_continuation_token".format(
cls.__name__
)
)


class NoPolling(PollingMethod):
"""An empty poller that returns the deserialized initial response.
"""
Expand All @@ -72,7 +88,7 @@ def __init__(self):
self._deserialization_callback = None

def initialize(self, _, initial_response, deserialization_callback):
# type: (Any, requests.Response, Callable) -> None
# type: (Any, Any, Callable) -> None
self._initial_response = initial_response
self._deserialization_callback = deserialization_callback

Expand Down Expand Up @@ -101,6 +117,22 @@ def resource(self):
# type: () -> Any
return self._deserialization_callback(self._initial_response)

def get_continuation_token(self):
# type() -> str
import pickle
return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii')

@classmethod
def from_continuation_token(cls, continuation_token, **kwargs):
# type(str, Any) -> Tuple
try:
deserialization_callback = kwargs["deserialization_callback"]
except KeyError:
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token")
import pickle
initial_response = pickle.loads(base64.b64decode(continuation_token))
return None, initial_response, deserialization_callback


class LROPoller(Generic[PollingReturnType]):
"""Poller for long running operations.
Expand All @@ -118,9 +150,7 @@ class LROPoller(Generic[PollingReturnType]):
"""

def __init__(self, client, initial_response, deserialization_callback, polling_method):
# type: (Any, HttpResponse, DeserializationCallbackType, PollingMethod) -> None
self._client = client
self._response = initial_response
# type: (Any, HttpResponse, Callable, PollingMethod[PollingReturnType]) -> None
self._callbacks = [] # type: List[Callable]
self._polling_method = polling_method

Expand All @@ -131,7 +161,7 @@ def __init__(self, client, initial_response, deserialization_callback, polling_m
pass

# Might raise a CloudError
self._polling_method.initialize(self._client, self._response, deserialization_callback)
self._polling_method.initialize(client, initial_response, deserialization_callback)

# Prepare thread execution
self._thread = None
Expand Down Expand Up @@ -166,6 +196,23 @@ def _start(self):
call(self._polling_method)
callbacks, self._callbacks = self._callbacks, []

def continuation_token(self):
# type: () -> str
"""Return a continuation token that allows to restart the poller later.

:returns: An opaque continuation token
:rtype: str
"""
return self._polling_method.get_continuation_token()

@classmethod
def from_continuation_token(cls, polling_method, continuation_token, **kwargs):
# type: (PollingMethod[PollingReturnType], str, Any) -> LROPoller[PollingReturnType]
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
return cls(client, initial_response, deserialization_callback, polling_method)

def status(self):
# type: () -> str
"""Returns the current status string.
Expand Down
25 changes: 25 additions & 0 deletions sdk/core/azure-core/azure/core/polling/base_polling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#
# --------------------------------------------------------------------------
import abc
import base64
import json
from typing import TYPE_CHECKING, Optional, Any, Union

Expand Down Expand Up @@ -447,6 +448,30 @@ def initialize(self, client, initial_response, deserialization_callback):
except OperationFailed as err:
raise HttpResponseError(response=initial_response.http_response, error=err)

def get_continuation_token(self):
# type() -> str
import pickle
return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii')

@classmethod
def from_continuation_token(cls, continuation_token, **kwargs):
# type(str, Any) -> Tuple
try:
client = kwargs["client"]
except KeyError:
raise ValueError("Need kwarg 'client' to be recreated from continuation_token")

try:
deserialization_callback = kwargs["deserialization_callback"]
except KeyError:
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token")

import pickle
annatisch marked this conversation as resolved.
Show resolved Hide resolved
initial_response = pickle.loads(base64.b64decode(continuation_token))
# Restore the transport in the context
initial_response.context.transport = client._pipeline._transport # pylint: disable=protected-access
return client, initial_response, deserialization_callback

def run(self):
try:
self._poll()
Expand Down
Loading