Skip to content

Commit

Permalink
run black on code (Azure#21033)
Browse files Browse the repository at this point in the history
  • Loading branch information
catalinaperalta authored Oct 4, 2021
1 parent dc22490 commit d332e47
Show file tree
Hide file tree
Showing 15 changed files with 176 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def validate_api_version(api_version, client_kind):
api_version = DocumentAnalysisApiVersion(api_version)
err_message += (
"\nAPI version '{}' is only available for "
"DocumentAnalysisClient and DocumentModelAdministrationClient.".format(api_version)
"DocumentAnalysisClient and DocumentModelAdministrationClient.".format(
api_version
)
)
except ValueError:
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,15 @@ class DocumentAnalysisClient(FormRecognizerClientBase):

def __init__(self, endpoint, credential, **kwargs):
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None
api_version = kwargs.pop("api_version", DocumentAnalysisApiVersion.V2021_09_30_PREVIEW)
api_version = kwargs.pop(
"api_version", DocumentAnalysisApiVersion.V2021_09_30_PREVIEW
)
super(DocumentAnalysisClient, self).__init__(
endpoint=endpoint, credential=credential, api_version=api_version, client_kind="document", **kwargs
endpoint=endpoint,
credential=credential,
api_version=api_version,
client_kind="document",
**kwargs
)

def _analyze_document_callback(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
from azure.core.paging import ItemPaged
from ._helpers import TransportWrapper
from ._api_versions import DocumentAnalysisApiVersion
from ._polling import DocumentModelAdministrationPolling, DocumentModelAdministrationLROPoller
from ._polling import (
DocumentModelAdministrationPolling,
DocumentModelAdministrationLROPoller,
)
from ._form_base_client import FormRecognizerClientBase
from ._document_analysis_client import DocumentAnalysisClient
from ._models import (
DocumentModel,
DocumentModelInfo,
ModelOperation,
ModelOperationInfo,
AccountInfo
AccountInfo,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -81,9 +84,15 @@ class DocumentModelAdministrationClient(FormRecognizerClientBase):

def __init__(self, endpoint, credential, **kwargs):
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None
api_version = kwargs.pop("api_version", DocumentAnalysisApiVersion.V2021_09_30_PREVIEW)
api_version = kwargs.pop(
"api_version", DocumentAnalysisApiVersion.V2021_09_30_PREVIEW
)
super(DocumentModelAdministrationClient, self).__init__(
endpoint=endpoint, credential=credential, api_version=api_version, client_kind="document", **kwargs
endpoint=endpoint,
credential=credential,
api_version=api_version,
client_kind="document",
**kwargs
)

@distributed_trace
Expand Down Expand Up @@ -121,8 +130,12 @@ def begin_build_model(self, source, **kwargs):
"""

def callback(raw_response, _, headers): # pylint: disable=unused-argument
op_response = self._deserialize(self._generated_models.GetOperationResponse, raw_response)
model_info = self._deserialize(self._generated_models.ModelInfo, op_response.result)
op_response = self._deserialize(
self._generated_models.GetOperationResponse, raw_response
)
model_info = self._deserialize(
self._generated_models.ModelInfo, op_response.result
)
return DocumentModel._from_generated(model_info)

description = kwargs.pop("description", None)
Expand All @@ -148,7 +161,9 @@ def callback(raw_response, _, headers): # pylint: disable=unused-argument
cls=cls,
continuation_token=continuation_token,
polling=LROBasePolling(
timeout=polling_interval, lro_algorithms=[DocumentModelAdministrationPolling()], **kwargs
timeout=polling_interval,
lro_algorithms=[DocumentModelAdministrationPolling()],
**kwargs
),
**kwargs
)
Expand Down Expand Up @@ -185,8 +200,12 @@ def begin_create_composed_model(self, model_ids, **kwargs):
def _compose_callback(
raw_response, _, headers
): # pylint: disable=unused-argument
op_response = self._deserialize(self._generated_models.GetOperationResponse, raw_response)
model_info = self._deserialize(self._generated_models.ModelInfo, op_response.result)
op_response = self._deserialize(
self._generated_models.GetOperationResponse, raw_response
)
model_info = self._deserialize(
self._generated_models.ModelInfo, op_response.result
)
return DocumentModel._from_generated(model_info)

model_id = kwargs.pop("model_id", None)
Expand All @@ -206,7 +225,9 @@ def _compose_callback(
component_models=[
self._generated_models.ComponentModelInfo(model_id=model_id)
for model_id in model_ids
] if model_ids else []
]
if model_ids
else [],
),
cls=kwargs.pop("cls", _compose_callback),
polling=LROBasePolling(
Expand Down Expand Up @@ -242,15 +263,13 @@ def get_copy_authorization(self, **kwargs):

response = self._client.authorize_copy_document_model(
authorize_copy_request=self._generated_models.AuthorizeCopyRequest(
model_id=model_id,
description=description
model_id=model_id, description=description
),
**kwargs
)
target = response.serialize() # type: ignore
return target


@distributed_trace
def begin_copy_model(
self,
Expand Down Expand Up @@ -287,8 +306,12 @@ def begin_copy_model(
"""

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
op_response = self._deserialize(self._generated_models.GetOperationResponse, raw_response)
model_info = self._deserialize(self._generated_models.ModelInfo, op_response.result)
op_response = self._deserialize(
self._generated_models.GetOperationResponse, raw_response
)
model_info = self._deserialize(
self._generated_models.ModelInfo, op_response.result
)
return DocumentModel._from_generated(model_info)

if not model_id:
Expand All @@ -308,10 +331,14 @@ def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
access_token=target["accessToken"],
expiration_date_time=target["expirationDateTime"],
target_model_location=target["targetModelLocation"],
) if target else None,
)
if target
else None,
cls=kwargs.pop("cls", _copy_callback),
polling=LROBasePolling(
timeout=polling_interval, lro_algorithms=[DocumentModelAdministrationPolling()], **kwargs
timeout=polling_interval,
lro_algorithms=[DocumentModelAdministrationPolling()],
**kwargs
),
continuation_token=continuation_token,
**kwargs
Expand Down Expand Up @@ -365,10 +392,7 @@ def list_models(self, **kwargs):
return self._client.get_models(
cls=kwargs.pop(
"cls",
lambda objs: [
DocumentModelInfo._from_generated(x)
for x in objs
],
lambda objs: [DocumentModelInfo._from_generated(x) for x in objs],
),
**kwargs
)
Expand Down Expand Up @@ -418,10 +442,7 @@ def get_model(self, model_id, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

response = self._client.get_model(
model_id=model_id,
**kwargs
)
response = self._client.get_model(model_id=model_id, **kwargs)
return DocumentModel._from_generated(response)

@distributed_trace
Expand Down Expand Up @@ -450,10 +471,7 @@ def list_operations(self, **kwargs):
return self._client.get_operations(
cls=kwargs.pop(
"cls",
lambda objs: [
ModelOperationInfo._from_generated(x)
for x in objs
],
lambda objs: [ModelOperationInfo._from_generated(x) for x in objs],
),
**kwargs
)
Expand Down Expand Up @@ -486,7 +504,8 @@ def get_operation(self, operation_id, **kwargs):
raise ValueError("'operation_id' cannot be None or empty.")

return ModelOperation._from_generated(
self._client.get_operation(operation_id, **kwargs), api_version=self._api_version
self._client.get_operation(operation_id, **kwargs),
api_version=self._api_version,
)

def get_document_analysis_client(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from azure.core.pipeline.policies import HttpLoggingPolicy
from ._generated._form_recognizer_client import FormRecognizerClient as FormRecognizer
from ._api_versions import validate_api_version
from ._helpers import _get_deserialize, get_authentication_policy, POLLING_INTERVAL, QuotaExceededPolicy
from ._helpers import (
_get_deserialize,
get_authentication_policy,
POLLING_INTERVAL,
QuotaExceededPolicy,
)
from ._user_agent import USER_AGENT

if TYPE_CHECKING:
Expand All @@ -22,9 +27,7 @@ def __init__(self, endpoint, credential, **kwargs):
self._credential = credential
self._api_version = kwargs.pop("api_version", None)
if not self._api_version:
raise ValueError(
"'api_version' must be specified."
)
raise ValueError("'api_version' must be specified.")
if self._api_version.startswith("v"): # v2.0 released with this option
self._api_version = self._api_version[1:]

Expand Down Expand Up @@ -57,7 +60,7 @@ def __init__(self, endpoint, credential, **kwargs):
"pages",
"readingOrder",
"stringIndexType",
"api-version"
"api-version",
}
)

Expand All @@ -66,7 +69,9 @@ def __init__(self, endpoint, credential, **kwargs):
credential=credential, # type: ignore
api_version=self._api_version,
sdk_moniker=USER_AGENT,
authentication_policy=kwargs.get("authentication_policy", authentication_policy),
authentication_policy=kwargs.get(
"authentication_policy", authentication_policy
),
http_logging_policy=kwargs.get("http_logging_policy", http_logging_policy),
per_retry_policies=kwargs.get("per_retry_policies", QuotaExceededPolicy()),
polling_interval=polling_interval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def __init__(self, endpoint, credential, **kwargs):
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None
api_version = kwargs.pop("api_version", FormRecognizerApiVersion.V2_1)
super(FormRecognizerClient, self).__init__(
endpoint=endpoint, credential=credential, api_version=api_version, client_kind="form", **kwargs
endpoint=endpoint,
credential=credential,
api_version=api_version,
client_kind="form",
**kwargs
)

def _prebuilt_callback(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ def __init__(self, endpoint, credential, **kwargs):
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None
api_version = kwargs.pop("api_version", FormRecognizerApiVersion.V2_1)
super(FormTrainingClient, self).__init__(
endpoint=endpoint, credential=credential, api_version=api_version, client_kind="form", **kwargs
endpoint=endpoint,
credential=credential,
api_version=api_version,
client_kind="form",
**kwargs
)

@distributed_trace
Expand Down Expand Up @@ -199,7 +203,9 @@ def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
cls=deserialization_callback,
continuation_token=continuation_token,
polling=LROBasePolling(
timeout=polling_interval, lro_algorithms=[FormTrainingPolling()], **kwargs
timeout=polling_interval,
lro_algorithms=[FormTrainingPolling()],
**kwargs
),
**kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def on_response(self, request, response):
:type response: ~azure.core.pipeline.PipelineResponse
"""
http_response = response.http_response
if http_response.status_code in [403, 429] and \
"Out of call volume quota for FormRecognizer F0 pricing tier" in http_response.text():
if (
http_response.status_code in [403, 429]
and "Out of call volume quota for FormRecognizer F0 pricing tier"
in http_response.text()
):
raise HttpResponseError(http_response.text(), response=http_response)
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
AsyncHttpResponse,
HttpRequest,
)

ResponseType = Union[HttpResponse, AsyncHttpResponse]
PipelineResponseType = PipelineResponse[HttpRequest, ResponseType]



def raise_error(response, errors, message):
error_message = "({}) {}{}".format(errors[0]["code"], errors[0]["message"], message)
error = HttpResponseError(message=error_message, response=response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# ------------------------------------

from ._document_analysis_client_async import DocumentAnalysisClient
from ._document_model_administration_client_async import DocumentModelAdministrationClient
from ._document_model_administration_client_async import (
DocumentModelAdministrationClient,
)
from ._form_recognizer_client_async import FormRecognizerClient
from ._form_training_client_async import FormTrainingClient
from ._async_polling import AsyncDocumentModelAdministrationLROPoller
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ def last_updated_on(self):

@classmethod
def from_continuation_token(
cls,
polling_method: AsyncPollingMethod[PollingReturnType],
continuation_token: str,
**kwargs: Any
cls,
polling_method: AsyncPollingMethod[PollingReturnType],
continuation_token: str,
**kwargs: Any
) -> "AsyncDocumentModelAdministrationLROPoller":
(
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from .._api_versions import DocumentAnalysisApiVersion
from ._form_base_client_async import FormRecognizerClientBaseAsync
from .._models import AnalyzeResult

if TYPE_CHECKING:
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential



class DocumentAnalysisClient(FormRecognizerClientBaseAsync):
"""DocumentAnalysisClient analyzes information from documents and images.
It is the interface to use for analyzing with prebuilt models (receipts, business cards,
Expand Down Expand Up @@ -64,9 +64,15 @@ def __init__(
credential: Union["AzureKeyCredential", "AsyncTokenCredential"],
**kwargs: Any
) -> None:
api_version = kwargs.pop("api_version", DocumentAnalysisApiVersion.V2021_09_30_PREVIEW)
api_version = kwargs.pop(
"api_version", DocumentAnalysisApiVersion.V2021_09_30_PREVIEW
)
super(DocumentAnalysisClient, self).__init__(
endpoint=endpoint, credential=credential, api_version=api_version, client_kind="document", **kwargs
endpoint=endpoint,
credential=credential,
api_version=api_version,
client_kind="document",
**kwargs
)

def _analyze_document_callback(
Expand Down
Loading

0 comments on commit d332e47

Please sign in to comment.