diff --git a/sdk/core/azure-core/azure/core/pipeline/__init__.py b/sdk/core/azure-core/azure/core/pipeline/__init__.py index 08854053a761..cec8606569ae 100644 --- a/sdk/core/azure-core/azure/core/pipeline/__init__.py +++ b/sdk/core/azure-core/azure/core/pipeline/__init__.py @@ -69,8 +69,21 @@ def __init__(self, transport, **kwargs): # pylint: disable=super-init-not-calle self.options = kwargs self._protected = ["transport", "options"] + def __getstate__(self): + state = self.__dict__.copy() + # Remove the unpicklable entries. + del state['transport'] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + # Re-create the unpickable entries + self.transport = None + def __setitem__(self, key, item): - if key in self._protected: + # If reloaded from pickle, _protected might not be here until restored by pickle + # this explains the hasattr test + if hasattr(self, '_protected') and key in self._protected: raise ValueError("Context value {} cannot be overwritten.".format(key)) return super(PipelineContext, self).__setitem__(key, item) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 9b0089401513..9a40c616bdff 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -299,3 +299,14 @@ def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: :type pipeline: azure.core.pipeline """ return AioHttpStreamDownloadGenerator(pipeline, self) + + def __getstate__(self): + # Be sure body is loaded in memory, otherwise not pickable and let it throw + self.body() + + state = self.__dict__.copy() + # Remove the unpicklable entries. + state['internal_response'] = None # aiohttp response are not pickable (see headers comments) + from multidict import MultiDict # I know it's importable since aiohttp is loaded + state['headers'] = MultiDict(self.headers) # MultiDictProxy is not pickable + return state diff --git a/sdk/core/azure-core/azure/core/polling/__init__.py b/sdk/core/azure-core/azure/core/polling/__init__.py index 6d3109ab457e..660c3e586090 100644 --- a/sdk/core/azure-core/azure/core/polling/__init__.py +++ b/sdk/core/azure-core/azure/core/polling/__init__.py @@ -31,5 +31,5 @@ #pylint: disable=unused-import if sys.version_info >= (3, 5, 2): # Not executed on old Python, no syntax error - from ._async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller - __all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller'] + from ._async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller, AsyncLROPoller + __all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller', 'AsyncLROPoller'] diff --git a/sdk/core/azure-core/azure/core/polling/_async_poller.py b/sdk/core/azure-core/azure/core/polling/_async_poller.py index cb71d054cfa4..3a792a280a14 100644 --- a/sdk/core/azure-core/azure/core/polling/_async_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_async_poller.py @@ -23,7 +23,9 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from typing import Generic, TypeVar, Any +from collections.abc import Awaitable +from typing import Callable, Any, Tuple, Generic, TypeVar, Generator + from ._poller import NoPolling as _NoPolling @@ -48,6 +50,21 @@ def finished(self) -> bool: def resource(self) -> PollingReturnType: raise NotImplementedError("This method needs to be implemented") + def get_continuation_token(self) -> str: + raise TypeError( + "Polling method '{}' doesn't support get_continuation_token".format( + self.__class__.__name__ + ) + ) + + @classmethod + def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any, Any, Callable]: + raise TypeError( + "Polling method '{}' doesn't support from_continuation_token".format( + cls.__name__ + ) + ) + class AsyncNoPolling(_NoPolling): """An empty async poller that returns the deserialized initial response. @@ -61,6 +78,9 @@ async def run(self): async def async_poller(client, initial_response, deserialization_callback, polling_method): """Async Poller for long running operations. + .. deprecated:: 1.5.0 + Use :class:`AsyncLROPoller` instead. + :param client: A pipeline service client. :type client: ~azure.core.PipelineClient :param initial_response: The initial call response @@ -71,15 +91,100 @@ async def async_poller(client, initial_response, deserialization_callback, polli :param polling_method: The polling strategy to adopt :type polling_method: ~azure.core.polling.PollingMethod """ + poller = AsyncLROPoller(client, initial_response, deserialization_callback, polling_method) + return await poller + + +class AsyncLROPoller(Awaitable, Generic[PollingReturnType]): + """Async poller for long running operations. + + :param client: A pipeline service client + :type client: ~azure.core.PipelineClient + :param initial_response: The initial call response + :type initial_response: + ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse + :param deserialization_callback: A callback that takes a Response and return a deserialized object. + If a subclass of Model is given, this passes "deserialize" as callback. + :type deserialization_callback: callable or msrest.serialization.Model + :param polling_method: The polling strategy to adopt + :type polling_method: ~azure.core.polling.AsyncPollingMethod + """ + + def __init__( + self, + client: Any, + initial_response: Any, + deserialization_callback: Callable, + polling_method: AsyncPollingMethod[PollingReturnType] + ): + self._polling_method = polling_method + self._done = False + + # This implicit test avoids bringing in an explicit dependency on Model directly + try: + deserialization_callback = deserialization_callback.deserialize # type: ignore + except AttributeError: + pass + + self._polling_method.initialize(client, initial_response, deserialization_callback) + + def polling_method(self) -> AsyncPollingMethod[PollingReturnType]: + """Return the polling method associated to this poller. + """ + return self._polling_method + + def continuation_token(self) -> str: + """Return a continuation token that allows to restart the poller later. + + :returns: An opaque continuation token + :rtype: str + """ + return self._polling_method.get_continuation_token() + + @classmethod + def from_continuation_token( + cls, + polling_method: AsyncPollingMethod[PollingReturnType], + continuation_token: str, + **kwargs + ) -> "AsyncLROPoller[PollingReturnType]": + client, initial_response, deserialization_callback = polling_method.from_continuation_token( + continuation_token, **kwargs + ) + return cls(client, initial_response, deserialization_callback, polling_method) - # This implicit test avoids bringing in an explicit dependency on Model directly - try: - deserialization_callback = deserialization_callback.deserialize - except AttributeError: - pass + def status(self) -> str: + """Returns the current status string. + + :returns: The current status string + :rtype: str + """ + return self._polling_method.status() - # Might raise a CloudError - polling_method.initialize(client, initial_response, deserialization_callback) + async def result(self) -> PollingReturnType: + """Return the result of the long running operation. - await polling_method.run() - return polling_method.resource() + :returns: The deserialized resource of the long running operation, if one is available. + :raises ~azure.core.exceptions.HttpResponseError: Server problem with the query. + """ + await self.wait() + return self._polling_method.resource() + + def __await__(self) -> Generator[Any, None, PollingReturnType]: + return self.result().__await__() + + async def wait(self) -> None: + """Wait on the long running operation. + + :raises ~azure.core.exceptions.HttpResponseError: Server problem with the query. + """ + await self._polling_method.run() + self._done = True + + def done(self) -> bool: + """Check status of the long running operation. + + :returns: 'True' if the process has completed, else 'False'. + :rtype: bool + """ + return self._done diff --git a/sdk/core/azure-core/azure/core/polling/_poller.py b/sdk/core/azure-core/azure/core/polling/_poller.py index 033af831be9b..4dff2ab220f6 100644 --- a/sdk/core/azure-core/azure/core/polling/_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_poller.py @@ -23,6 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +import base64 import threading import uuid try: @@ -30,17 +31,14 @@ except ImportError: from urllib.parse import urlparse -from typing import Any, Callable, Union, List, Optional, TypeVar, Generic, TYPE_CHECKING +from typing import Any, Callable, Union, List, Optional, Tuple, TypeVar, Generic from azure.core.pipeline.transport._base import HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.common import with_current_context -if TYPE_CHECKING: - import requests - from msrest.serialization import Model # pylint: disable=unused-import - DeserializationCallbackType = Union[Model, Callable[[requests.Response], Model]] PollingReturnType = TypeVar("PollingReturnType") + class PollingMethod(Generic[PollingReturnType]): """ABC class for polling method. """ @@ -64,6 +62,24 @@ def resource(self): # type: () -> PollingReturnType raise NotImplementedError("This method needs to be implemented") + def get_continuation_token(self): + # type() -> str + raise TypeError( + "Polling method '{}' doesn't support get_continuation_token".format( + self.__class__.__name__ + ) + ) + + @classmethod + def from_continuation_token(cls, continuation_token, **kwargs): + # type(str, Any) -> Tuple[Any, Any, Callable] + raise TypeError( + "Polling method '{}' doesn't support from_continuation_token".format( + cls.__name__ + ) + ) + + class NoPolling(PollingMethod): """An empty poller that returns the deserialized initial response. """ @@ -72,7 +88,7 @@ def __init__(self): self._deserialization_callback = None def initialize(self, _, initial_response, deserialization_callback): - # type: (Any, requests.Response, Callable) -> None + # type: (Any, Any, Callable) -> None self._initial_response = initial_response self._deserialization_callback = deserialization_callback @@ -101,6 +117,22 @@ def resource(self): # type: () -> Any return self._deserialization_callback(self._initial_response) + def get_continuation_token(self): + # type() -> str + import pickle + return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii') + + @classmethod + def from_continuation_token(cls, continuation_token, **kwargs): + # type(str, Any) -> Tuple + try: + deserialization_callback = kwargs["deserialization_callback"] + except KeyError: + raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") + import pickle + initial_response = pickle.loads(base64.b64decode(continuation_token)) + return None, initial_response, deserialization_callback + class LROPoller(Generic[PollingReturnType]): """Poller for long running operations. @@ -118,9 +150,7 @@ class LROPoller(Generic[PollingReturnType]): """ def __init__(self, client, initial_response, deserialization_callback, polling_method): - # type: (Any, HttpResponse, DeserializationCallbackType, PollingMethod) -> None - self._client = client - self._response = initial_response + # type: (Any, HttpResponse, Callable, PollingMethod[PollingReturnType]) -> None self._callbacks = [] # type: List[Callable] self._polling_method = polling_method @@ -131,7 +161,7 @@ def __init__(self, client, initial_response, deserialization_callback, polling_m pass # Might raise a CloudError - self._polling_method.initialize(self._client, self._response, deserialization_callback) + self._polling_method.initialize(client, initial_response, deserialization_callback) # Prepare thread execution self._thread = None @@ -166,6 +196,29 @@ def _start(self): call(self._polling_method) callbacks, self._callbacks = self._callbacks, [] + def polling_method(self): + # type: () -> PollingMethod[PollingReturnType] + """Return the polling method associated to this poller. + """ + return self._polling_method + + def continuation_token(self): + # type: () -> str + """Return a continuation token that allows to restart the poller later. + + :returns: An opaque continuation token + :rtype: str + """ + return self._polling_method.get_continuation_token() + + @classmethod + def from_continuation_token(cls, polling_method, continuation_token, **kwargs): + # type: (PollingMethod[PollingReturnType], str, Any) -> LROPoller[PollingReturnType] + client, initial_response, deserialization_callback = polling_method.from_continuation_token( + continuation_token, **kwargs + ) + return cls(client, initial_response, deserialization_callback, polling_method) + def status(self): # type: () -> str """Returns the current status string. diff --git a/sdk/core/azure-core/azure/core/polling/base_polling.py b/sdk/core/azure-core/azure/core/polling/base_polling.py index af8c3b037c9b..7ddcd6ef0007 100644 --- a/sdk/core/azure-core/azure/core/polling/base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/base_polling.py @@ -24,6 +24,7 @@ # # -------------------------------------------------------------------------- import abc +import base64 import json from typing import TYPE_CHECKING, Optional, Any, Union @@ -447,6 +448,30 @@ def initialize(self, client, initial_response, deserialization_callback): except OperationFailed as err: raise HttpResponseError(response=initial_response.http_response, error=err) + def get_continuation_token(self): + # type() -> str + import pickle + return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii') + + @classmethod + def from_continuation_token(cls, continuation_token, **kwargs): + # type(str, Any) -> Tuple + try: + client = kwargs["client"] + except KeyError: + raise ValueError("Need kwarg 'client' to be recreated from continuation_token") + + try: + deserialization_callback = kwargs["deserialization_callback"] + except KeyError: + raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") + + import pickle + initial_response = pickle.loads(base64.b64decode(continuation_token)) + # Restore the transport in the context + initial_response.context.transport = client._pipeline._transport # pylint: disable=protected-access + return client, initial_response, deserialization_callback + def run(self): try: self._poll() diff --git a/sdk/core/azure-core/tests/azure_core_asynctests/test_async_base_polling.py b/sdk/core/azure-core/tests/azure_core_asynctests/test_async_base_polling.py index ff31a98fdbc8..b1749806b84c 100644 --- a/sdk/core/azure-core/tests/azure_core_asynctests/test_async_base_polling.py +++ b/sdk/core/azure-core/tests/azure_core_asynctests/test_async_base_polling.py @@ -42,7 +42,7 @@ from azure.core.polling import async_poller from azure.core.exceptions import DecodeError, HttpResponseError from azure.core import AsyncPipelineClient -from azure.core.pipeline import PipelineResponse, AsyncPipeline +from azure.core.pipeline import PipelineResponse, AsyncPipeline, PipelineContext from azure.core.pipeline.transport import AsyncioRequestsTransportResponse, AsyncHttpTransport from azure.core.polling.async_base_polling import ( @@ -87,6 +87,12 @@ async def mock_run(client_self, request, **kwargs): CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) +@pytest.fixture +def client(): + # The poller itself don't use it, so we don't need something functionnal + return AsyncPipelineClient("https://baseurl") + + @pytest.fixture def async_pipeline_client_builder(): """Build a client that use the "send" callback as final transport layer @@ -118,6 +124,42 @@ def cb(pipeline_response): return cb +@pytest.fixture +def polling_response(): + polling = AsyncLROBasePolling() + headers = {} + + response = Response() + response.headers = headers + response.status_code = 200 + + polling._pipeline_response = PipelineResponse( + None, + AsyncioRequestsTransportResponse( + None, + response, + ), + PipelineContext(None) + ) + polling._initial_response = polling._pipeline_response + return polling, headers + + +def test_base_polling_continuation_token(client, polling_response): + polling, _ = polling_response + + continuation_token = polling.get_continuation_token() + assert isinstance(continuation_token, str) + + polling_args = AsyncLROBasePolling.from_continuation_token( + continuation_token, + deserialization_callback="deserialization_callback", + client=client, + ) + new_polling = AsyncLROBasePolling() + new_polling.initialize(*polling_args) + + @pytest.mark.asyncio async def test_post(async_pipeline_client_builder, deserialization_cb): diff --git a/sdk/core/azure-core/tests/azure_core_asynctests/test_polling.py b/sdk/core/azure-core/tests/azure_core_asynctests/test_polling.py new file mode 100644 index 000000000000..31f3ef924952 --- /dev/null +++ b/sdk/core/azure-core/tests/azure_core_asynctests/test_polling.py @@ -0,0 +1,199 @@ +#-------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +#-------------------------------------------------------------------------- +import asyncio +import time +try: + from unittest import mock +except ImportError: + import mock + +import pytest + +from azure.core import AsyncPipelineClient +from azure.core.polling import * +from msrest.serialization import Model + + +@pytest.fixture +def client(): + # The poller itself don't use it, so we don't need something functionnal + return AsyncPipelineClient("https://baseurl") + + +@pytest.mark.asyncio +async def test_no_polling(client): + no_polling = AsyncNoPolling() + + initial_response = "initial response" + def deserialization_cb(response): + assert response == initial_response + return "Treated: "+response + + no_polling.initialize(client, initial_response, deserialization_cb) + await no_polling.run() # Should no raise and do nothing + assert no_polling.status() == "succeeded" + assert no_polling.finished() + assert no_polling.resource() == "Treated: "+initial_response + + continuation_token = no_polling.get_continuation_token() + assert isinstance(continuation_token, str) + + no_polling_revived_args = NoPolling.from_continuation_token( + continuation_token, + deserialization_callback=deserialization_cb, + client=client + ) + no_polling_revived = NoPolling() + no_polling_revived.initialize(*no_polling_revived_args) + assert no_polling_revived.status() == "succeeded" + assert no_polling_revived.finished() + assert no_polling_revived.resource() == "Treated: "+initial_response + + +class PollingTwoSteps(AsyncPollingMethod): + """An empty poller that returns the deserialized initial response. + """ + def __init__(self, sleep=0): + self._initial_response = None + self._deserialization_callback = None + self._sleep = sleep + + def initialize(self, _, initial_response, deserialization_callback): + self._initial_response = initial_response + self._deserialization_callback = deserialization_callback + self._finished = False + + async def run(self): + """Empty run, no polling. + """ + self._finished = True + await asyncio.sleep(self._sleep) # Give me time to add callbacks! + + def status(self): + """Return the current status as a string. + :rtype: str + """ + return "succeeded" if self._finished else "running" + + def finished(self): + """Is this polling finished? + :rtype: bool + """ + return self._finished + + def resource(self): + return self._deserialization_callback(self._initial_response) + + def get_continuation_token(self): + return self._initial_response + + @classmethod + def from_continuation_token(cls, continuation_token, **kwargs): + # type(str, Any) -> Tuple + initial_response = continuation_token + deserialization_callback = kwargs['deserialization_callback'] + return None, initial_response, deserialization_callback + + +@pytest.mark.asyncio +async def test_poller(client): + + # Same the poller itself doesn't care about the initial_response, and there is no type constraint here + initial_response = "Initial response" + + # Same for deserialization_callback, just pass to the polling_method + def deserialization_callback(response): + assert response == initial_response + return "Treated: "+response + + method = AsyncNoPolling() + + raw_poller = AsyncLROPoller(client, initial_response, deserialization_callback, method) + poller = asyncio.ensure_future(raw_poller.result()) + + done_cb = mock.MagicMock() + poller.add_done_callback(done_cb) + + result = await poller + assert poller.done() + assert result == "Treated: "+initial_response + assert raw_poller.status() == "succeeded" + assert raw_poller.polling_method() is method + done_cb.assert_called_once_with(poller) + + # Test with a basic Model + poller = AsyncLROPoller(client, initial_response, Model, method) + assert poller._polling_method._deserialization_callback == Model.deserialize + + # Test poller that method do a run + method = PollingTwoSteps(sleep=1) + raw_poller = AsyncLROPoller(client, initial_response, deserialization_callback, method) + poller = asyncio.ensure_future(raw_poller.result()) + + done_cb = mock.MagicMock() + done_cb2 = mock.MagicMock() + poller.add_done_callback(done_cb) + poller.remove_done_callback(done_cb2) + + result = await poller + assert result == "Treated: "+initial_response + assert raw_poller.status() == "succeeded" + done_cb.assert_called_once_with(poller) + done_cb2.assert_not_called() + + # Test continuation token + cont_token = raw_poller.continuation_token() + + method = PollingTwoSteps(sleep=1) + new_poller = AsyncLROPoller.from_continuation_token( + continuation_token=cont_token, + client=client, + initial_response=initial_response, + deserialization_callback=deserialization_callback, + polling_method=method + ) + result = await new_poller.result() + assert result == "Treated: "+initial_response + assert new_poller.status() == "succeeded" + + +@pytest.mark.asyncio +async def test_broken_poller(client): + + class NoPollingError(PollingTwoSteps): + async def run(self): + raise ValueError("Something bad happened") + + initial_response = "Initial response" + def deserialization_callback(response): + return "Treated: "+response + + method = NoPollingError() + poller = AsyncLROPoller(client, initial_response, deserialization_callback, method) + + with pytest.raises(ValueError) as excinfo: + await poller.result() + assert "Something bad happened" in str(excinfo.value) diff --git a/sdk/core/azure-core/tests/test_base_polling.py b/sdk/core/azure-core/tests/test_base_polling.py index 6446014be67c..840ab4cfe880 100644 --- a/sdk/core/azure-core/tests/test_base_polling.py +++ b/sdk/core/azure-core/tests/test_base_polling.py @@ -29,6 +29,7 @@ import types import platform import unittest +import six try: from unittest import mock except ImportError: @@ -43,7 +44,7 @@ from azure.core.polling import LROPoller from azure.core.exceptions import DecodeError, HttpResponseError from azure.core import PipelineClient -from azure.core.pipeline import PipelineResponse, Pipeline +from azure.core.pipeline import PipelineResponse, Pipeline, PipelineContext from azure.core.pipeline.transport import RequestsTransportResponse, HttpTransport from azure.core.polling.base_polling import LROBasePolling @@ -86,6 +87,12 @@ def mock_run(client_self, request, **kwargs): CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) +@pytest.fixture +def client(): + # The poller itself don't use it, so we don't need something functionnal + return PipelineClient("https://baseurl") + + @pytest.fixture def pipeline_client_builder(): """Build a client that use the "send" callback as final transport layer @@ -116,6 +123,7 @@ def cb(pipeline_response): return json.loads(pipeline_response.http_response.text()) return cb + @pytest.fixture def polling_response(): polling = LROBasePolling() @@ -123,6 +131,7 @@ def polling_response(): response = Response() response.headers = headers + response.status_code = 200 polling._pipeline_response = PipelineResponse( None, @@ -130,10 +139,27 @@ def polling_response(): None, response, ), - None # context + PipelineContext(None) ) + polling._initial_response = polling._pipeline_response return polling, headers + +def test_base_polling_continuation_token(client, polling_response): + polling, _ = polling_response + + continuation_token = polling.get_continuation_token() + assert isinstance(continuation_token, six.string_types) + + polling_args = LROBasePolling.from_continuation_token( + continuation_token, + deserialization_callback="deserialization_callback", + client=client, + ) + new_polling = LROBasePolling() + new_polling.initialize(*polling_args) + + def test_delay_extraction_int(polling_response): polling, headers = polling_response diff --git a/sdk/core/azure-core/tests/test_polling.py b/sdk/core/azure-core/tests/test_polling.py index 597db5221865..02626bcad2cb 100644 --- a/sdk/core/azure-core/tests/test_polling.py +++ b/sdk/core/azure-core/tests/test_polling.py @@ -30,11 +30,17 @@ import mock import pytest +import six +from azure.core import PipelineClient from azure.core.polling import * -from msrest.service_client import ServiceClient from msrest.serialization import Model -from msrest.configuration import Configuration + + +@pytest.fixture +def client(): + # The poller itself don't use it, so we don't need something functionnal + return PipelineClient("https://baseurl") def test_abc_polling(): @@ -55,7 +61,14 @@ def test_abc_polling(): with pytest.raises(NotImplementedError): abc_polling.resource() -def test_no_polling(): + with pytest.raises(TypeError): + abc_polling.get_continuation_token() + + with pytest.raises(TypeError): + abc_polling.from_continuation_token("token") + + +def test_no_polling(client): no_polling = NoPolling() initial_response = "initial response" @@ -63,12 +76,26 @@ def deserialization_cb(response): assert response == initial_response return "Treated: "+response - no_polling.initialize(None, initial_response, deserialization_cb) + no_polling.initialize(client, initial_response, deserialization_cb) no_polling.run() # Should no raise and do nothing assert no_polling.status() == "succeeded" assert no_polling.finished() assert no_polling.resource() == "Treated: "+initial_response + continuation_token = no_polling.get_continuation_token() + assert isinstance(continuation_token, six.string_types) + + no_polling_revived_args = NoPolling.from_continuation_token( + continuation_token, + deserialization_callback=deserialization_cb, + client=client + ) + no_polling_revived = NoPolling() + no_polling_revived.initialize(*no_polling_revived_args) + assert no_polling_revived.status() == "succeeded" + assert no_polling_revived.finished() + assert no_polling_revived.resource() == "Treated: "+initial_response + class PollingTwoSteps(PollingMethod): """An empty poller that returns the deserialized initial response. @@ -104,11 +131,16 @@ def finished(self): def resource(self): return self._deserialization_callback(self._initial_response) -@pytest.fixture -def client(): - # We need a ServiceClient instance, but the poller itself don't use it, so we don't need - # Something functionnal - return ServiceClient(None, Configuration("http://example.org")) + def get_continuation_token(self): + return self._initial_response + + @classmethod + def from_continuation_token(cls, continuation_token, **kwargs): + # type(str, Any) -> Tuple + initial_response = continuation_token + deserialization_callback = kwargs['deserialization_callback'] + return None, initial_response, deserialization_callback + def test_poller(client): @@ -131,6 +163,7 @@ def deserialization_callback(response): assert poller.done() assert result == "Treated: "+initial_response assert poller.status() == "succeeded" + assert poller.polling_method() is method done_cb.assert_called_once_with(method) # Test with a basic Model @@ -156,6 +189,22 @@ def deserialization_callback(response): poller.remove_done_callback(done_cb) assert "Process is complete" in str(excinfo.value) + # Test continuation token + cont_token = poller.continuation_token() + + method = PollingTwoSteps(sleep=1) + new_poller = LROPoller.from_continuation_token( + continuation_token=cont_token, + client=client, + initial_response=initial_response, + deserialization_callback=deserialization_callback, + polling_method=method + ) + result = new_poller.result() + assert result == "Treated: "+initial_response + assert new_poller.status() == "succeeded" + + def test_broken_poller(client): class NoPollingError(PollingTwoSteps):