Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[formrecognizer] reduce time for recorded tests runs #11970

Merged
merged 12 commits into from
Jun 15, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ def __init__(self, endpoint, credential, **kwargs):
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we documenting this polling_interval kwarg to the client for users? I think a good idea could be to call it _polling_interval if we're the only people expecting the use it, but I don't feel too strongly about this

Copy link
Member Author

@kristapratico kristapratico Jun 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an azure-core setting, so hopefully once Xiang documents all the core keywords we'll be able to link to that :)

Edit: this should be available to users

self._client = FormRecognizer(
endpoint=endpoint,
credential=credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -111,7 +113,7 @@ def begin_recognize_receipts(self, receipt, **kwargs):
:caption: Recognize US sales receipt fields.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -162,7 +164,7 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs):
:caption: Recognize US sales receipt fields from a URL.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down Expand Up @@ -210,7 +212,7 @@ def begin_recognize_content(self, form, **kwargs):
:caption: Recognize text and content/layout information from a form.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -246,7 +248,7 @@ def begin_recognize_content_from_url(self, form_url, **kwargs):
:raises ~azure.core.exceptions.HttpResponseError:
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

return self._client.begin_analyze_layout_async(
Expand Down Expand Up @@ -296,7 +298,7 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs):
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -348,7 +350,7 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs):
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.core.tracing.decorator import distributed_trace
from azure.core.polling import LROPoller
from azure.core.polling.base_polling import LROBasePolling
from azure.core.pipeline import Pipeline
from ._generated._form_recognizer_client import FormRecognizerClient as FormRecognizer
from ._generated.models import (
TrainRequest,
Expand All @@ -26,7 +27,7 @@
CopyOperationResult,
CopyAuthorizationResult
)
from ._helpers import error_map, get_authentication_policy, POLLING_INTERVAL
from ._helpers import error_map, get_authentication_policy, POLLING_INTERVAL, TransportWrapper
from ._models import (
CustomFormModelInfo,
AccountProperties,
Expand Down Expand Up @@ -78,11 +79,13 @@ def __init__(self, endpoint, credential, **kwargs):
self._endpoint = endpoint
self._credential = credential
authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
self._client = FormRecognizer(
endpoint=self._endpoint,
credential=self._credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -129,7 +132,7 @@ def callback(raw_response):

cls = kwargs.pop("cls", None)
continuation_token = kwargs.pop("continuation_token", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
deserialization_callback = cls if cls else callback

if continuation_token:
Expand Down Expand Up @@ -339,7 +342,7 @@ def begin_copy_model(
if not model_id:
raise ValueError("model_id cannot be None or empty.")

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
Expand Down Expand Up @@ -371,11 +374,20 @@ def get_form_recognizer_client(self, **kwargs):
:rtype: ~azure.ai.formrecognizer.FormRecognizerClient
:return: A FormRecognizerClient
"""
return FormRecognizerClient(

_pipeline = Pipeline(
transport=TransportWrapper(self._client._client._pipeline._transport),
policies=self._client._client._pipeline._impl_policies
) # type: Pipeline
client = FormRecognizerClient(
endpoint=self._endpoint,
credential=self._credential,
pipeline=_pipeline,
**kwargs
)
# need to share config, but can't pass as a keyword into client
client._client._config = self._client._client._config
return client

def close(self):
# type: () -> None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import six
from azure.core.credentials import AzureKeyCredential
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
from azure.core.pipeline.transport import HttpTransport
from azure.core.exceptions import (
ResourceNotFoundError,
ResourceExistsError,
Expand All @@ -24,6 +25,30 @@
}


class TransportWrapper(HttpTransport):
"""Wrapper class that ensures that an inner client created
by a `get_client` method does not close the outer transport for the parent
when used in a context manager.
"""
def __init__(self, transport):
self._transport = transport

def send(self, request, **kwargs):
return self._transport.send(request, **kwargs)

def open(self):
pass

def close(self):
pass

def __enter__(self):
pass

def __exit__(self, *args): # pylint: disable=arguments-differ
pass


def get_authentication_policy(credential):
authentication_policy = None
if credential is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def __init__(
) -> None:

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
self._client = FormRecognizer(
endpoint=endpoint,
credential=credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -119,7 +121,7 @@ async def begin_recognize_receipts(
:caption: Recognize US sales receipt fields.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -176,7 +178,7 @@ async def begin_recognize_receipts_from_url(
:caption: Recognize US sales receipt fields from a URL.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down Expand Up @@ -230,7 +232,7 @@ async def begin_recognize_content(
:caption: Recognize text and content/layout information from a form.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -268,7 +270,7 @@ async def begin_recognize_content_from_url(self, form_url: str, **kwargs: Any) -
:raises ~azure.core.exceptions.HttpResponseError:
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
return await self._client.begin_analyze_layout_async( # type: ignore
file_stream={"source": form_url},
Expand Down Expand Up @@ -324,7 +326,7 @@ async def begin_recognize_custom_forms(
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -385,7 +387,7 @@ async def begin_recognize_custom_forms_from_url(
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
TYPE_CHECKING,
)
from azure.core.polling import AsyncLROPoller
from azure.core.pipeline import AsyncPipeline
from azure.core.polling.async_base_polling import AsyncLROBasePolling
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from ._form_recognizer_client_async import FormRecognizerClient
from ._helpers_async import AsyncTransportWrapper
from .._generated.aio._form_recognizer_client_async import FormRecognizerClient as FormRecognizer
from .._generated.models import (
TrainRequest,
Expand Down Expand Up @@ -81,13 +83,14 @@ def __init__(
) -> None:
self._endpoint = endpoint
self._credential = credential

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
self._client = FormRecognizer(
endpoint=self._endpoint,
credential=self._credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -138,7 +141,7 @@ def callback(raw_response):

cls = kwargs.pop("cls", None)
continuation_token = kwargs.pop("continuation_token", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
deserialization_callback = cls if cls else callback

if continuation_token:
Expand Down Expand Up @@ -361,7 +364,7 @@ async def begin_copy_model(
raise ValueError("model_id cannot be None or empty.")

continuation_token = kwargs.pop("continuation_token", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
copy_result = self._client._deserialize(CopyOperationResult, raw_response)
Expand Down Expand Up @@ -395,11 +398,19 @@ def get_form_recognizer_client(self, **kwargs: Any) -> FormRecognizerClient:
:rtype: ~azure.ai.formrecognizer.aio.FormRecognizerClient
:return: A FormRecognizerClient
"""
return FormRecognizerClient(
_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._client._client._pipeline._transport),
policies=self._client._client._pipeline._impl_policies
) # type: AsyncPipeline
client = FormRecognizerClient(
endpoint=self._endpoint,
credential=self._credential,
pipeline=_pipeline,
**kwargs
)
# need to share config, but can't pass as a keyword into client
client._client._config = self._client._client._config
return client

async def __aenter__(self) -> "FormTrainingClient":
await self._client.__aenter__()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# coding=utf-8
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from azure.core.pipeline.transport import AsyncHttpTransport


class AsyncTransportWrapper(AsyncHttpTransport):
"""Wrapper class that ensures that an inner client created
by a `get_client` method does not close the outer transport for the parent
when used in a context manager.
"""
def __init__(self, async_transport):
self._transport = async_transport

async def send(self, request, **kwargs):
return await self._transport.send(request, **kwargs)

async def open(self):
pass

async def close(self):
pass

async def __aenter__(self):
pass

async def __aexit__(self, *args): # pylint: disable=arguments-differ
pass
Loading