diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index f2c0d5effc..5a632fcc02 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -566,6 +566,13 @@ def __getattr__(self, name): @utils.cached_property def client_output(self): + return self._client_output(enable_asyncio=False) + + @utils.cached_property + def client_output_async(self): + return self._client_output(enable_asyncio=True) + + def _client_output(self, enable_asyncio: bool): """Return the output from the client layer. This takes into account transformations made by the outer GAPIC @@ -584,8 +591,8 @@ def client_output(self): if self.lro: return PythonType(meta=metadata.Metadata( address=metadata.Address( - name='Operation', - module='operation', + name='AsyncOperation' if enable_asyncio else 'Operation', + module='operation_async' if enable_asyncio else 'operation', package=('google', 'api_core'), collisions=self.lro.response_type.ident.collisions, ), @@ -603,7 +610,7 @@ def client_output(self): if self.paged_result_field: return PythonType(meta=metadata.Metadata( address=metadata.Address( - name=f'{self.name}Pager', + name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager', package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + ( 'services', utils.to_snake_case(self.ident.parent[-1]), @@ -744,6 +751,8 @@ def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]: if not self.void: answer.append(self.client_output) answer.extend(self.client_output.field_types) + answer.append(self.client_output_async) + answer.extend(self.client_output_async.field_types) # If this method has LRO, it is possible (albeit unlikely) that # the LRO messages reside in a different module. @@ -801,6 +810,11 @@ def client_name(self) -> str: """Returns the name of the generated client class""" return self.name + "Client" + @property + def async_client_name(self) -> str: + """Returns the name of the generated AsyncIO client class""" + return self.name + "AsyncClient" + @property def transport_name(self): return self.name + "Transport" @@ -809,6 +823,10 @@ def transport_name(self): def grpc_transport_name(self): return self.name + "GrpcTransport" + @property + def grpc_asyncio_transport_name(self): + return self.name + "GrpcAsyncIOTransport" + @property def has_lro(self) -> bool: """Return whether the service has a long-running method.""" @@ -856,7 +874,7 @@ def names(self) -> FrozenSet[str]: used for imports. """ # Put together a set of the service and method names. - answer = {self.name, self.client_name} + answer = {self.name, self.client_name, self.async_client_name} answer.update( utils.to_snake_case(i.name) for i in self.methods.values() ) diff --git a/gapic/templates/%namespace/%name/__init__.py.j2 b/gapic/templates/%namespace/%name/__init__.py.j2 index 15f4a17e44..d777dc86e3 100644 --- a/gapic/templates/%namespace/%name/__init__.py.j2 +++ b/gapic/templates/%namespace/%name/__init__.py.j2 @@ -12,6 +12,8 @@ from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.' if service.meta.address.subpackage == api.subpackage_view -%} from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%} {{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client import {{ service.client_name }} +from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%} + {{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client import {{ service.async_client_name }} {% endfor -%} {# Import messages and enums from each proto. @@ -48,6 +50,7 @@ __all__ = ( {% for service in api.services.values()|sort(attribute='name') if service.meta.address.subpackage == api.subpackage_view -%} '{{ service.client_name }}', +'{{ service.async_client_name }}', {% endfor -%} {% for proto in api.protos.values()|sort(attribute='module_name') if proto.meta.address.subpackage == api.subpackage_view -%} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 index f9f07d44df..c99b2a5f91 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 @@ -2,8 +2,10 @@ {% block content %} from .client import {{ service.client_name }} +from .async_client import {{ service.async_client_name }} __all__ = ( '{{ service.client_name }}', + '{{ service.async_client_name }}', ) {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 new file mode 100644 index 0000000000..fb501fe2bc --- /dev/null +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 @@ -0,0 +1,271 @@ +{% extends '_base.py.j2' %} + +{% block content %} +from collections import OrderedDict +import functools +import re +from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +{% filter sort_lines -%} +{% for method in service.methods.values() -%} +{% for ref_type in method.flat_ref_types -%} +{{ ref_type.ident.python_import }} +{% endfor -%} +{% endfor -%} +{% endfilter %} +from .transports.base import {{ service.name }}Transport +from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }} +from .client import {{ service.client_name }} + + +class {{ service.async_client_name }}: + """{{ service.meta.doc|rst(width=72, indent=4) }}""" + + _client: {{ service.client_name }} + + DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT + + {% for message in service.resource_messages -%} + {{ message.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.{{ message.resource_type|snake_case }}_path) + + {% endfor %} + + from_service_account_file = {{ service.client_name }}.from_service_account_file + from_service_account_json = from_service_account_file + + get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }})) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, {{ service.name }}Transport] = 'grpc_asyncio', + client_options: ClientOptions = None, + ) -> None: + """Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.{{ service.name }}Transport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint, this is the default value for + the environment variable) and "auto" (auto switch to the default + mTLS endpoint if client SSL credentials is present). However, + the ``api_endpoint`` property takes precedence if provided. + (2) The ``client_cert_source`` property is used to provide client + SSL credentials for mutual TLS transport. If not provided, the + default SSL credentials will be used if present. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + {# NOTE(lidiz) Not using kwargs since we want the docstring and types. #} + self._client = {{ service.client_name }}( + credentials=credentials, + transport=transport, + client_options=client_options, + ) + + {% for method in service.methods.values() -%} + {% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self, + {%- if not method.client_streaming %} + request: {{ method.input.ident }} = None, + *, + {% for field in method.flattened_fields.values() -%} + {{ field.name }}: {{ field.ident }} = None, + {% endfor -%} + {%- else %} + requests: AsyncIterator[{{ method.input.ident }}] = None, + *, + {% endif -%} + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + {%- if not method.server_streaming %} + ) -> {{ method.client_output_async.ident }}: + {%- else %} + ) -> AsyncIterable[{{ method.client_output_async.ident }}]: + {%- endif %} + r"""{{ method.meta.doc|rst(width=72, indent=8) }} + + Args: + {%- if not method.client_streaming %} + request (:class:`{{ method.input.ident.sphinx }}`): + The request object.{{ ' ' -}} + {{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }} + {% for key, field in method.flattened_fields.items() -%} + {{ field.name }} (:class:`{{ field.ident.sphinx }}`): + {{ field.meta.doc|rst(width=72, indent=16, nl=False) }} + This corresponds to the ``{{ key }}`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + {% endfor -%} + {%- else %} + requests (AsyncIterator[`{{ method.input.ident.sphinx }}`]): + The request object AsyncIterator.{{ ' ' -}} + {{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }} + {%- endif %} + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + {%- if not method.void %} + + Returns: + {%- if not method.server_streaming %} + {{ method.client_output_async.ident.sphinx }}: + {%- else %} + AsyncIterable[{{ method.client_output_async.ident.sphinx }}]: + {%- endif %} + {{ method.client_output_async.meta.doc|rst(width=72, indent=16) }} + {%- endif %} + """ + {%- if not method.client_streaming %} + # Create or coerce a protobuf request object. + {% if method.flattened_fields -%} + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + {% endif -%} + {% if method.input.ident.package != method.ident.package -%} {# request lives in a different package, so there is no proto wrapper #} + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = {{ method.input.ident }}(**request) + {% if method.flattened_fields -%}{# Cross-package req and flattened fields #} + elif not request: + request = {{ method.input.ident }}() + {% endif -%}{# Cross-package req and flattened fields #} + {%- else %} + request = {{ method.input.ident }}(request) + {% endif %} {# different request package #} + + {#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #} + {% if method.flattened_fields -%} + # If we have keyword arguments corresponding to fields on the + # request, apply these. + {% endif -%} + {%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %} + if {{ field.name }} is not None: + request.{{ key }} = {{ field.name }} + {%- endfor %} + {# They can be _extended_, however -#} + {%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %} + if {{ field.name }}: + request.{{ key }}.extend({{ field.name }}) + {%- endfor %} + {%- endif %} + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.{{ method.name|snake_case }}, + {%- if method.retry %} + default_retry=retries.Retry( + {% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %} + {% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %} + {% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %} + predicate=retries.if_exception_type( + {%- filter sort_lines %} + {%- for ex in method.retry.retryable_exceptions %} + exceptions.{{ ex.__name__ }}, + {%- endfor %} + {%- endfilter %} + ), + ), + {%- endif %} + default_timeout={{ method.timeout }}, + client_info=_client_info, + ) + {%- if method.field_headers %} + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + {%- for field_header in method.field_headers %} + {%- if not method.client_streaming %} + ('{{ field_header }}', request.{{ field_header }}), + {%- endif %} + {%- endfor %} + )), + ) + {%- endif %} + + # Send the request. + {% if not method.void %}response = {% endif %} + {%- if not method.server_streaming %}await {% endif %}rpc( + {%- if not method.client_streaming %} + request, + {%- else %} + requests, + {%- endif %} + retry=retry, + timeout=timeout, + metadata=metadata, + ) + {%- if method.lro %} + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + {{ method.lro.response_type.ident }}, + metadata_type={{ method.lro.metadata_type.ident }}, + ) + {%- elif method.paged_result_field %} + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = {{ method.client_output_async.ident }}( + method=rpc, + request=request, + response=response, + ) + {%- endif %} + {%- if not method.void %} + + # Done; return the response. + return response + {%- endif %} + {{ '\n' }} + {% endfor %} + + +try: + _client_info = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + '{{ api.naming.warehouse_package_name }}', + ).version, + ) +except pkg_resources.DistributionNotFound: + _client_info = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + '{{ service.async_client_name }}', +) +{% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 6e25957f63..48efc9de0b 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -24,7 +24,8 @@ from google.oauth2 import service_account # type: ignore {% endfor -%} {% endfilter %} from .transports.base import {{ service.name }}Transport -from .transports.grpc import {{ service.name }}GrpcTransport +from .transports.grpc import {{ service.grpc_transport_name }} +from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }} class {{ service.client_name }}Meta(type): @@ -35,11 +36,12 @@ class {{ service.client_name }}Meta(type): objects. """ _transport_registry = OrderedDict() # type: Dict[str, Type[{{ service.name }}Transport]] - _transport_registry['grpc'] = {{ service.name }}GrpcTransport + _transport_registry['grpc'] = {{ service.grpc_transport_name }} + _transport_registry['grpc_asyncio'] = {{ service.grpc_asyncio_transport_name }} def get_transport_class(cls, label: str = None, - ) -> Type[{{ service.name }}Transport]: + ) -> Type[{{ service.name }}Transport]: """Return an appropriate transport class. Args: @@ -148,7 +150,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): transport to use. If set to None, a transport is chosen automatically. client_options (ClientOptions): Custom options for the client. It - won't take effect unless ``transport`` is None. + won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS environment variable can also be used to override the endpoint: @@ -170,7 +172,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): if client_options is None: client_options = ClientOptions.ClientOptions() - if transport is None and client_options.api_endpoint is None: + if client_options.api_endpoint is None: use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") if use_mtls_env == "never": client_options.api_endpoint = self.DEFAULT_ENDPOINT @@ -198,13 +200,9 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): raise ValueError('When providing a transport instance, ' 'provide its credentials directly.') self._transport = transport - elif isinstance(transport, str): + else: Transport = type(self).get_transport_class(transport) self._transport = Transport( - credentials=credentials, host=self.DEFAULT_ENDPOINT - ) - else: - self._transport = {{ service.name }}GrpcTransport( credentials=credentials, host=client_options.api_endpoint, api_mtls_endpoint=client_options.api_endpoint, diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 index 0e7ef018a7..5c069b68fd 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 @@ -6,7 +6,7 @@ {# This lives within the loop in order to ensure that this template is empty if there are no paged methods. -#} -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable {% filter sort_lines -%} {% for method in service.methods.values() | selectattr('paged_result_field') -%} @@ -71,5 +71,64 @@ class {{ method.name }}Pager: def __repr__(self) -> str: return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + +class {{ method.name }}AsyncPager: + """A pager for iterating through ``{{ method.name|snake_case }}`` requests. + + This class thinly wraps an initial + :class:`{{ method.output.ident.sphinx }}` object, and + provides an ``__aiter__`` method to iterate through its + ``{{ method.paged_result_field.name }}`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``{{ method.name }}`` requests and continue to iterate + through the ``{{ method.paged_result_field.name }}`` field on the + corresponding responses. + + All the usual :class:`{{ method.output.ident.sphinx }}` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[[{{ method.input.ident }}], + Awaitable[{{ method.output.ident }}]], + request: {{ method.input.ident }}, + response: {{ method.output.ident }}): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`{{ method.input.ident.sphinx }}`): + The initial request object. + response (:class:`{{ method.output.ident.sphinx }}`): + The initial response object. + """ + self._method = method + self._request = {{ method.input.ident }}(request) + self._response = response + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[{{ method.output.ident }}]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request) + yield self._response + + def __aiter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'AsyncIterable') }}: + async def async_generator(): + async for page in self.pages: + for response in page.{{ method.paged_result_field.name }}: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + {% endfor %} {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 index 470cde5d19..fa97f46164 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 @@ -6,15 +6,18 @@ from typing import Dict, Type from .base import {{ service.name }}Transport from .grpc import {{ service.name }}GrpcTransport +from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[{{ service.name }}Transport]] _transport_registry['grpc'] = {{ service.name }}GrpcTransport +_transport_registry['grpc_asyncio'] = {{ service.name }}GrpcAsyncIOTransport __all__ = ( '{{ service.name }}Transport', '{{ service.name }}GrpcTransport', + '{{ service.name }}GrpcAsyncIOTransport', ) {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 index 694e0a1664..6eaf999459 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 @@ -17,7 +17,7 @@ from google.auth import credentials # type: ignore {% endfor -%} {% endfilter %} -class {{ service.name }}Transport(metaclass=abc.ABCMeta): +class {{ service.name }}Transport(abc.ABC): """Abstract transport class for {{ service.name }}.""" AUTH_SCOPES = ( @@ -30,6 +30,7 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta): self, *, host: str{% if service.host %} = '{{ service.host }}'{% endif %}, credentials: credentials.Credentials = None, + **kwargs, ) -> None: """Instantiate the transport. @@ -59,15 +60,18 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta): @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() {%- endif %} {%- for method in service.methods.values() %} @property def {{ method.name|snake_case }}(self) -> typing.Callable[ [{{ method.input.ident }}], - {{ method.output.ident }}]: - raise NotImplementedError + typing.Union[ + {{ method.output.ident }}, + typing.Awaitable[{{ method.output.ident }}] + ]]: + raise NotImplementedError() {%- endfor %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index 1632b77621..7288972b8c 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -1,7 +1,7 @@ {% extends '_base.py.j2' %} {% block content %} -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore {%- if service.has_lro %} @@ -35,6 +35,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] + def __init__(self, *, host: str{% if service.host %} = '{{ service.host }}'{% endif %}, credentials: credentials.Credentials = None, @@ -64,8 +66,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): is None. Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. """ if channel: # Sanity check: Ensure that channel and credentials are not both @@ -91,7 +93,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): ssl_credentials = SslCredentials().ssl_credentials # create a new channel. The provided one is ignored. - self._grpc_channel = grpc_helpers.create_channel( + self._grpc_channel = type(self).create_channel( host, credentials=credentials, ssl_credentials=ssl_credentials, @@ -102,11 +104,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - @classmethod def create_channel(cls, host: str{% if service.host %} = '{{ service.host }}'{% endif %}, credentials: credentials.Credentials = None, + scopes: Optional[Sequence[str]] = None, **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -116,15 +118,19 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( host, credentials=credentials, - scopes=cls.AUTH_SCOPES, + scopes=scopes, **kwargs ) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 new file mode 100644 index 0000000000..53fd1c7188 --- /dev/null +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 @@ -0,0 +1,207 @@ +{% extends '_base.py.j2' %} + +{% block content %} +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers_async # type: ignore +{%- if service.has_lro %} +from google.api_core import operations_v1 # type: ignore +{%- endif %} +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +{% filter sort_lines -%} +{% for method in service.methods.values() -%} +{{ method.input.ident.python_import }} +{{ method.output.ident.python_import }} +{% endfor -%} +{% endfilter %} +from .base import {{ service.name }}Transport +from .grpc import {{ service.name }}GrpcTransport + + +class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): + """gRPC AsyncIO backend transport for {{ service.name }}. + + {{ service.meta.doc|rst(width=72, indent=4) }} + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str{% if service.host %} = '{{ service.host }}'{% endif %}, + credentials: credentials.Credentials = None, + scopes: Optional[Sequence[str]] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + scopes=scopes, + **kwargs + ) + + def __init__(self, *, + host: str{% if service.host %} = '{{ service.host }}'{% endif %}, + credentials: credentials.Credentials = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None: + """Instantiate the transport. + + Args: + host ({% if service.host %}Optional[str]{% else %}str{% endif %}): + {{- ' ' }}The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If + provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A + callback to provide client SSL certificate bytes and private key + bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` + is None. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + ssl_credentials=ssl_credentials, + scopes=self.AUTH_SCOPES, + ) + + # Run the base constructor. + super().__init__(host=host, credentials=credentials) + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, '_grpc_channel'): + self._grpc_channel = self.create_channel( + self._host, + credentials=self._credentials, + ) + + # Return the channel from cache. + return self._grpc_channel + {%- if service.has_lro %} + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + {%- endif %} + {%- for method in service.methods.values() %} + + @property + def {{ method.name|snake_case }}(self) -> Callable[ + [{{ method.input.ident }}], + Awaitable[{{ method.output.ident }}]]: + r"""Return a callable for the {{- ' ' -}} + {{ (method.name|snake_case).replace('_',' ')|wrap( + width=70, offset=40, indent=8) }} + {{- ' ' -}} method over gRPC. + + {{ method.meta.doc|rst(width=72, indent=8) }} + + Returns: + Callable[[~.{{ method.input.name }}], + Awaitable[~.{{ method.output.name }}]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if '{{ method.name|snake_case }}' not in self._stubs: + self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}( + '/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}', + request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %}, + response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %}, + ) + return self._stubs['{{ method.name|snake_case }}'] + {%- endfor %} + + +__all__ = ( + '{{ service.name }}GrpcAsyncIOTransport', +) +{%- endblock -%} diff --git a/gapic/templates/noxfile.py.j2 b/gapic/templates/noxfile.py.j2 index 71f99a4144..d31a325e2f 100644 --- a/gapic/templates/noxfile.py.j2 +++ b/gapic/templates/noxfile.py.j2 @@ -10,7 +10,7 @@ import nox # type: ignore def unit(session): """Run the unit test suite.""" - session.install('coverage', 'pytest', 'pytest-cov') + session.install('coverage', 'pytest', 'pytest-cov', 'asyncmock', 'pytest-asyncio') session.install('-e', '.') session.run( diff --git a/gapic/templates/setup.py.j2 b/gapic/templates/setup.py.j2 index e400754b11..ccfb661812 100644 --- a/gapic/templates/setup.py.j2 +++ b/gapic/templates/setup.py.j2 @@ -16,7 +16,7 @@ setuptools.setup( platforms='Posix; MacOS X; Windows', include_package_data=True, install_requires=( - 'google-api-core[grpc] >= 1.17.0, < 2.0.0dev', + 'google-api-core[grpc] >= 1.17.2, < 2.0.0dev', 'libcst >= 0.2.5', 'proto-plus >= 0.4.0', {%- if api.requires_package(('google', 'iam', 'v1')) %} diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 79d420da82..550ca1bac4 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -2,9 +2,10 @@ {% block content %} import os -from unittest import mock +import mock import grpc +from grpc.experimental import aio import math import pytest @@ -15,9 +16,11 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.oauth2 import service_account from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }} +from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.async_client_name }} from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports from google.api_core import client_options from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async {% if service.has_lro -%} from google.api_core import future from google.api_core import operations_v1 @@ -52,14 +55,15 @@ def test__get_default_mtls_endpoint(): assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi -def test_{{ service.client_name|snake_case }}_from_service_account_file(): +@pytest.mark.parametrize("client_class", [{{ service.client_name }}, {{ service.async_client_name }}]) +def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds - client = {{ service.client_name }}.from_service_account_file("dummy/file/path.json") + client = client_class.from_service_account_file("dummy/file/path.json") assert client._transport._credentials == creds - client = {{ service.client_name }}.from_service_account_json("dummy/file/path.json") + client = client_class.from_service_account_json("dummy/file/path.json") assert client._transport._credentials == creds {% if service.host %}assert client._transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %} @@ -73,26 +77,30 @@ def test_{{ service.client_name|snake_case }}_get_transport_class(): assert transport == transports.{{ service.name }}GrpcTransport -def test_{{ service.client_name|snake_case }}_client_options(): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"), + ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio") +]) +def test_{{ service.client_name|snake_case }}_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: - transport = transports.{{ service.name }}GrpcTransport( + with mock.patch.object({{ service.client_name }}, 'get_transport_class') as gtc: + transport = transport_class( credentials=credentials.AnonymousCredentials() ) - client = {{ service.client_name }}(transport=transport) + client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: - client = {{ service.client_name }}(transport="grpc") + with mock.patch.object({{ service.client_name }}, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: - grpc_transport.return_value = None - client = {{ service.client_name }}(client_options=options) - grpc_transport.assert_called_once_with( + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( api_mtls_endpoint="squid.clam.whelk", client_cert_source=None, credentials=None, @@ -102,10 +110,10 @@ def test_{{ service.client_name|snake_case }}_client_options(): # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is # "never". os.environ["GOOGLE_API_USE_MTLS"] = "never" - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: - grpc_transport.return_value = None - client = {{ service.client_name }}() - grpc_transport.assert_called_once_with( + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( api_mtls_endpoint=client.DEFAULT_ENDPOINT, client_cert_source=None, credentials=None, @@ -115,10 +123,10 @@ def test_{{ service.client_name|snake_case }}_client_options(): # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is # "always". os.environ["GOOGLE_API_USE_MTLS"] = "always" - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: - grpc_transport.return_value = None - client = {{ service.client_name }}() - grpc_transport.assert_called_once_with( + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, client_cert_source=None, credentials=None, @@ -129,10 +137,10 @@ def test_{{ service.client_name|snake_case }}_client_options(): # "auto", and client_cert_source is provided. os.environ["GOOGLE_API_USE_MTLS"] = "auto" options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: - grpc_transport.return_value = None - client = {{ service.client_name }}(client_options=options) - grpc_transport.assert_called_once_with( + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, client_cert_source=client_cert_source_callback, credentials=None, @@ -142,11 +150,11 @@ def test_{{ service.client_name|snake_case }}_client_options(): # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is # "auto", and default_client_cert_source is provided. os.environ["GOOGLE_API_USE_MTLS"] = "auto" - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + with mock.patch.object(transport_class, '__init__') as patched: with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - grpc_transport.return_value = None - client = {{ service.client_name }}() - grpc_transport.assert_called_once_with( + patched.return_value = None + client = client_class() + patched.assert_called_once_with( api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, client_cert_source=None, credentials=None, @@ -156,11 +164,11 @@ def test_{{ service.client_name|snake_case }}_client_options(): # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is # "auto", but client_cert_source and default_client_cert_source are None. os.environ["GOOGLE_API_USE_MTLS"] = "auto" - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + with mock.patch.object(transport_class, '__init__') as patched: with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): - grpc_transport.return_value = None - client = {{ service.client_name }}() - grpc_transport.assert_called_once_with( + patched.return_value = None + client = client_class() + patched.assert_called_once_with( api_mtls_endpoint=client.DEFAULT_ENDPOINT, client_cert_source=None, credentials=None, @@ -171,7 +179,7 @@ def test_{{ service.client_name|snake_case }}_client_options(): # unsupported value. os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported" with pytest.raises(MutualTLSChannelError): - client = {{ service.client_name }}() + client = client_class() del os.environ["GOOGLE_API_USE_MTLS"] @@ -258,6 +266,89 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'): {% endfor %} {% endif %} + +@pytest.mark.asyncio +async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio'): + client = {{ service.async_client_name }}( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = {{ method.input.ident }}() + {% if method.client_streaming %} + requests = [request] + {% endif %} + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.{{ method.name|snake_case }}), + '__call__') as call: + # Designate an appropriate return value for the call. + {% if method.void -%} + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + {% elif method.lro -%} + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + {% elif not method.client_streaming and method.server_streaming -%} + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock(side_effect=[{{ method.output.ident }}()]) + {% elif method.client_streaming and method.server_streaming -%} + call.return_value = mock.Mock(aio.StreamStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock(side_effect=[{{ method.output.ident }}()]) + {% else -%} + call.return_value ={{' '}} + {%- if not method.client_streaming and not method.server_streaming -%} + grpc_helpers_async.FakeUnaryUnaryCall + {%- else -%} + grpc_helpers_async.FakeStreamUnaryCall + {%- endif -%}({{ method.output.ident }}( + {%- for field in method.output.fields.values() | rejectattr('message') %} + {{ field.name }}={{ field.mock_value }}, + {%- endfor %} + )) + {% endif -%} + {% if method.client_streaming and method.server_streaming %} + response = await client.{{ method.name|snake_case }}(iter(requests)) + {% elif method.client_streaming and not method.server_streaming %} + response = await (await client.{{ method.name|snake_case }}(iter(requests))) + {% else %} + response = await client.{{ method.name|snake_case }}(request) + {% endif %} + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + {% if method.client_streaming %} + assert next(args[0]) == request + {% else %} + assert args[0] == request + {% endif %} + + # Establish that the response is the type that we expect. + {% if method.void -%} + assert response is None + {% elif method.lro -%} + assert isinstance(response, future.Future) + {% elif method.server_streaming -%} + message = await response.read() + assert isinstance(message, {{ method.output.ident }}) + {% else -%} + assert isinstance(response, {{ method.client_output_async.ident }}) + {% for field in method.output.fields.values() | rejectattr('message') -%} + {% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#} + assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6) + {% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #} + assert response.{{ field.name }} is {{ field.mock_value }} + {% else -%} + assert response.{{ field.name }} == {{ field.mock_value }} + {% endif -%} + {% endfor %} + {% endif %} + + {% if method.field_headers and not method.client_streaming %} def test_{{ method.name|snake_case }}_field_headers(): client = {{ service.client_name }}( @@ -301,6 +392,52 @@ def test_{{ method.name|snake_case }}_field_headers(): {%- if not loop.last %}&{% endif -%} {%- endfor %}', ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_{{ method.name|snake_case }}_field_headers_async(): + client = {{ service.async_client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = {{ method.input.ident }}() + + {%- for field_header in method.field_headers %} + request.{{ field_header }} = '{{ field_header }}/value' + {%- endfor %} + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.{{ method.name|snake_case }}), + '__call__') as call: + {% if method.void -%} + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + {% elif method.lro -%} + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + {% elif method.server_streaming -%} + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock(side_effect=[{{ method.output.ident }}()]) + {% else -%} + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall({{ method.output.ident }}()) + {% endif %} + await client.{{ method.name|snake_case }}(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + '{% for field_header in method.field_headers -%} + {{ field_header }}={{ field_header }}/value + {%- if not loop.last %}&{% endif -%} + {%- endfor %}', + ) in kw['metadata'] {% endif %} {% if method.ident.package != method.input.ident.package %} @@ -383,6 +520,80 @@ def test_{{ method.name|snake_case }}_flattened_error(): {{ field.name }}={{ field.mock_value }}, {%- endfor %} ) + + +@pytest.mark.asyncio +async def test_{{ method.name|snake_case }}_flattened_async(): + client = {{ service.async_client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.{{ method.name|snake_case }}), + '__call__') as call: + # Designate an appropriate return value for the call. + {% if method.void -%} + call.return_value = None + {% elif method.lro -%} + call.return_value = operations_pb2.Operation(name='operations/op') + {% elif method.server_streaming -%} + call.return_value = iter([{{ method.output.ident }}()]) + {% else -%} + call.return_value = {{ method.output.ident }}() + {% endif %} + + + {% if method.void -%} + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + {% elif method.lro -%} + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + {% elif not method.client_streaming and method.server_streaming -%} + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + {% elif method.client_streaming and method.server_streaming -%} + call.return_value = mock.Mock(aio.StreamStreamCall, autospec=True) + {% else -%} + call.return_value ={{' '}} + {%- if not method.client_streaming and not method.server_streaming -%} + grpc_helpers_async.FakeUnaryUnaryCall + {%- else -%} + grpc_helpers_async.FakeStreamUnaryCall + {%- endif -%}({{ method.output.ident }}()) + {% endif -%} + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.{{ method.name|snake_case }}( + {%- for field in method.flattened_fields.values() %} + {{ field.name }}={{ field.mock_value }}, + {%- endfor %} + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + {% for key, field in method.flattened_fields.items() -%} + assert args[0].{{ key }} == {{ field.mock_value }} + {% endfor %} + + +@pytest.mark.asyncio +async def test_{{ method.name|snake_case }}_flattened_error_async(): + client = {{ service.async_client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.{{ method.name|snake_case }}( + {{ method.input.ident }}(), + {%- for field in method.flattened_fields.values() %} + {{ field.name }}={{ field.mock_value }}, + {%- endfor %} + ) {% endif %} @@ -471,6 +682,98 @@ def test_{{ method.name|snake_case }}_pages(): pages = list(client.{{ method.name|snake_case }}(request={}).pages) for page, token in zip(pages, ['abc','def','ghi', '']): assert page.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_{{ method.name|snake_case }}_async_pager(): + client = {{ service.async_client_name }}( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.{{ method.name|snake_case }}), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.message.ident }}(), + {{ method.paged_result_field.message.ident }}(), + {{ method.paged_result_field.message.ident }}(), + ], + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[], + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.message.ident }}(), + ], + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.message.ident }}(), + {{ method.paged_result_field.message.ident }}(), + ], + ), + RuntimeError, + ) + async_pager = await client.{{ method.name|snake_case }}(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, {{ method.paged_result_field.message.ident }}) + for i in responses) + +@pytest.mark.asyncio +async def test_{{ method.name|snake_case }}_async_pages(): + client = {{ service.async_client_name }}( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.{{ method.name|snake_case }}), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.message.ident }}(), + {{ method.paged_result_field.message.ident }}(), + {{ method.paged_result_field.message.ident }}(), + ], + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[], + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.message.ident }}(), + ], + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.message.ident }}(), + {{ method.paged_result_field.message.ident }}(), + ], + ), + RuntimeError, + ) + pages = [] + async for page in (await client.{{ method.name|snake_case }}(request={})).pages: + pages.append(page) + for page, token in zip(pages, ['abc','def','ghi', '']): + assert page.raw_page.next_page_token == token {% elif method.lro and "next_page_token" in method.lro.response_type.fields.keys() %} def test_{{ method.name|snake_case }}_raw_page_lro(): response = {{ method.lro.response_type.ident }}() @@ -500,6 +803,21 @@ def test_transport_instance(): assert client._transport is transport +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.{{ service.name }}GrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.{{ service.grpc_asyncio_transport_name }}( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = {{ service.client_name }}( @@ -598,6 +916,23 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel(): assert not callback.called +def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel(): + channel = aio.insecure_channel('http://localhost/') + + # Check that if channel is provided, mtls endpoint and client_cert_source + # won't be used. + callback = mock.MagicMock() + transport = transports.{{ service.name }}GrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=callback, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert not callback.called + + @mock.patch("grpc.ssl_channel_credentials", autospec=True) @mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_client_cert_source( @@ -635,6 +970,43 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_client_c assert transport.grpc_channel == mock_grpc_channel +@mock.patch("grpc.ssl_channel_credentials", autospec=True) +@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) +def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel_mtls_with_client_cert_source( + grpc_create_channel, grpc_ssl_channel_cred +): + # Check that if channel is None, but api_mtls_endpoint and client_cert_source + # are provided, then a mTLS channel will be created. + mock_cred = mock.Mock() + + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + transport = transports.{{ service.name }}GrpcAsyncIOTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=( + {%- for scope in service.oauth_scopes %} + '{{ scope }}', + {%- endfor %} + ), + ) + assert transport.grpc_channel == mock_grpc_channel + + @pytest.mark.parametrize( "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] ) @@ -674,6 +1046,45 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc( assert transport.grpc_channel == mock_grpc_channel +@pytest.mark.parametrize( + "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] +) +@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) +def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel_mtls_with_adc( + grpc_create_channel, api_mtls_endpoint +): + # Check that if channel and client_cert_source are None, but api_mtls_endpoint + # is provided, then a mTLS channel will be created with SSL ADC. + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + # Mock google.auth.transport.grpc.SslCredentials class. + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + mock_cred = mock.Mock() + transport = transports.{{ service.name }}GrpcAsyncIOTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint=api_mtls_endpoint, + client_cert_source=None, + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=( + {%- for scope in service.oauth_scopes %} + '{{ scope }}', + {%- endfor %} + ), + ) + assert transport.grpc_channel == mock_grpc_channel + + {% if service.has_lro -%} def test_{{ service.name|snake_case }}_grpc_lro_client(): client = {{ service.client_name }}( @@ -691,6 +1102,23 @@ def test_{{ service.name|snake_case }}_grpc_lro_client(): # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + +def test_{{ service.name|snake_case }}_grpc_lro_async_client(): + client = {{ service.async_client_name }}( + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', + ) + transport = client._client._transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + {% endif -%} {% for message in service.resource_messages -%} diff --git a/noxfile.py b/noxfile.py index 83d38f6042..6b4c3f14e6 100644 --- a/noxfile.py +++ b/noxfile.py @@ -63,6 +63,7 @@ def showcase( # Install pytest and gapic-generator-python session.install("mock") session.install("pytest") + session.install("pytest-asyncio") session.install("-e", ".") # Install a client library for Showcase. @@ -121,6 +122,7 @@ def showcase_mtls( # Install pytest and gapic-generator-python session.install("mock") session.install("pytest") + session.install("pytest-asyncio") session.install("-e", ".") # Install a client library for Showcase. @@ -182,15 +184,13 @@ def showcase_unit( ): """Run the generated unit tests against the Showcase library.""" - # Install pytest and gapic-generator-python session.install( - "coverage", "pytest", "pytest-cov", "pytest-xdist", + "coverage", "pytest", "pytest-cov", "pytest-xdist", 'asyncmock', 'pytest-asyncio' ) session.install(".") # Install a client library for Showcase. with tempfile.TemporaryDirectory() as tmp_dir: - # Download the Showcase descriptor. session.run( "curl", diff --git a/tests/system/conftest.py b/tests/system/conftest.py index ed21d45d54..b549eeb6b9 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -16,15 +16,26 @@ import mock import os import pytest +import asyncio import google.api_core.client_options as ClientOptions -from google import showcase from google.auth import credentials -from google.showcase import EchoClient -from google.showcase import IdentityClient +from google.showcase import EchoClient, EchoAsyncClient +from google.showcase import IdentityClient, IdentityAsyncClient from google.showcase import MessagingClient import grpc +from grpc.experimental import aio + + +# NOTE(lidiz) We must override the default event_loop fixture from +# pytest-asyncio. pytest fixture frees resources once there isn't any reference +# to it. So, the event loop might close before tests finishes. In the +# customized version, we don't close the event loop. +@pytest.fixture +def event_loop(): + loop = asyncio.get_event_loop() + return loop dir = os.path.dirname(__file__) @@ -52,7 +63,10 @@ def pytest_addoption(parser): ) -def construct_client(client_class, use_mtls): +def construct_client(client_class, + use_mtls, + transport="grpc", + channel_creator=grpc.insecure_channel): if use_mtls: with mock.patch("grpc.ssl_channel_credentials", autospec=True) as mock_ssl_cred: mock_ssl_cred.return_value = ssl_credentials @@ -65,8 +79,8 @@ def construct_client(client_class, use_mtls): ) return client else: - transport = client_class.get_transport_class("grpc")( - channel=grpc.insecure_channel("localhost:7469") + transport = client_class.get_transport_class(transport)( + channel=channel_creator("localhost:7469") ) return client_class(transport=transport) @@ -81,6 +95,34 @@ def echo(use_mtls): return construct_client(EchoClient, use_mtls) +@pytest.fixture +def async_echo(use_mtls, event_loop): + return construct_client( + EchoAsyncClient, + use_mtls, + transport="grpc_asyncio", + channel_creator=aio.insecure_channel + ) + + +@pytest.fixture +def identity(): + transport = IdentityClient.get_transport_class('grpc')( + channel=grpc.insecure_channel('localhost:7469'), + ) + return IdentityClient(transport=transport) + + +@pytest.fixture +def async_identity(use_mtls, event_loop): + return construct_client( + IdentityAsyncClient, + use_mtls, + transport="grpc_asyncio", + channel_creator=aio.insecure_channel + ) + + @pytest.fixture def identity(use_mtls): return construct_client(IdentityClient, use_mtls) diff --git a/tests/system/test_grpc_lro.py b/tests/system/test_grpc_lro.py index 617163a0db..a4578a168a 100644 --- a/tests/system/test_grpc_lro.py +++ b/tests/system/test_grpc_lro.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest from datetime import datetime, timedelta, timezone from google import showcase_v1beta1 @@ -27,3 +28,16 @@ def test_lro(echo): response = future.result() assert isinstance(response, showcase_v1beta1.WaitResponse) assert response.content.endswith('the snails...eventually.') + + +@pytest.mark.asyncio +async def test_lro_async(async_echo): + future = await async_echo.wait({ + 'end_time': datetime.now(tz=timezone.utc) + timedelta(seconds=1), + 'success': { + 'content': 'The hail in Wales falls mainly on the snails...eventually.' + }} + ) + response = await future.result() + assert isinstance(response, showcase_v1beta1.WaitResponse) + assert response.content.endswith('the snails...eventually.') diff --git a/tests/system/test_grpc_streams.py b/tests/system/test_grpc_streams.py index d0879d6a89..f77e819986 100644 --- a/tests/system/test_grpc_streams.py +++ b/tests/system/test_grpc_streams.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import pytest +import asyncio +import threading from google import showcase @@ -71,3 +75,131 @@ def test_stream_stream_passing_dict(echo): assert contents == ['hello', 'world!'] assert responses.trailing_metadata() == metadata + + +@pytest.mark.asyncio +async def test_async_unary_stream_reader(async_echo): + content = 'The hail in Wales falls mainly on the snails.' + call = await async_echo.expand({ + 'content': content, + }, metadata=metadata) + + # Consume the response and ensure it matches what we expect. + # with pytest.raises(exceptions.NotFound) as exc: + for ground_truth in content.split(' '): + response = await call.read() + assert response.content == ground_truth + assert ground_truth == 'snails.' + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_unary_stream_async_generator(async_echo): + content = 'The hail in Wales falls mainly on the snails.' + call = await async_echo.expand({ + 'content': content, + }, metadata=metadata) + + # Consume the response and ensure it matches what we expect. + # with pytest.raises(exceptions.NotFound) as exc: + tokens = iter(content.split(' ')) + async for response in call: + ground_truth = next(tokens) + assert response.content == ground_truth + assert ground_truth == 'snails.' + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_stream_unary_iterable(async_echo): + requests = [] + requests.append(showcase.EchoRequest(content="hello")) + requests.append(showcase.EchoRequest(content="world!")) + + call = await async_echo.collect(requests) + response = await call + assert response.content == 'hello world!' + + +@pytest.mark.asyncio +async def test_async_stream_unary_async_generator(async_echo): + + async def async_generator(): + yield showcase.EchoRequest(content="hello") + yield showcase.EchoRequest(content="world!") + + call = await async_echo.collect(async_generator()) + response = await call + assert response.content == 'hello world!' + + +@pytest.mark.asyncio +async def test_async_stream_unary_writer(async_echo): + call = await async_echo.collect() + await call.write(showcase.EchoRequest(content="hello")) + await call.write(showcase.EchoRequest(content="world!")) + await call.done_writing() + + response = await call + assert response.content == 'hello world!' + + +@pytest.mark.asyncio +async def test_async_stream_unary_passing_dict(async_echo): + requests = [{'content': 'hello'}, {'content': 'world!'}] + call = await async_echo.collect(iter(requests)) + response = await call + assert response.content == 'hello world!' + + +@pytest.mark.asyncio +async def test_async_stream_stream_reader_writier(async_echo): + call = await async_echo.chat(metadata=metadata) + await call.write(showcase.EchoRequest(content="hello")) + await call.write(showcase.EchoRequest(content="world!")) + await call.done_writing() + + contents = [ + (await call.read()).content, + (await call.read()).content + ] + assert contents == ['hello', 'world!'] + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_stream_stream_async_generator(async_echo): + + async def async_generator(): + yield showcase.EchoRequest(content="hello") + yield showcase.EchoRequest(content="world!") + + call = await async_echo.chat(async_generator(), metadata=metadata) + + contents = [] + async for response in call: + contents.append(response.content) + assert contents == ['hello', 'world!'] + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_stream_stream_passing_dict(async_echo): + requests = [{'content': 'hello'}, {'content': 'world!'}] + call = await async_echo.chat(iter(requests), metadata=metadata) + + contents = [] + async for response in call: + contents.append(response.content) + assert contents == ['hello', 'world!'] + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata diff --git a/tests/system/test_grpc_unary.py b/tests/system/test_grpc_unary.py index f8735a3a31..c1694d975e 100644 --- a/tests/system/test_grpc_unary.py +++ b/tests/system/test_grpc_unary.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +import asyncio from google.api_core import exceptions from google.rpc import code_pb2 @@ -45,3 +46,33 @@ def test_unary_error(echo): }) assert exc.value.code == 400 assert exc.value.message == message + + +@pytest.mark.asyncio +async def test_async_unary_with_request_object(async_echo): + response = await async_echo.echo(showcase.EchoRequest( + content='The hail in Wales falls mainly on the snails.', + ), timeout=1) + assert response.content == 'The hail in Wales falls mainly on the snails.' + + +@pytest.mark.asyncio +async def test_async_unary_with_dict(async_echo): + response = await async_echo.echo({ + 'content': 'The hail in Wales falls mainly on the snails.', + }) + assert response.content == 'The hail in Wales falls mainly on the snails.' + + +@pytest.mark.asyncio +async def test_async_unary_error(async_echo): + message = 'Bad things! Bad things!' + with pytest.raises(exceptions.InvalidArgument) as exc: + await async_echo.echo({ + 'error': { + 'code': code_pb2.Code.Value('INVALID_ARGUMENT'), + 'message': message, + }, + }) + assert exc.value.code == 400 + assert exc.value.message == message diff --git a/tests/system/test_pagination.py b/tests/system/test_pagination.py index 781614cad4..8f53a6c01d 100644 --- a/tests/system/test_pagination.py +++ b/tests/system/test_pagination.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest from google import showcase @@ -44,3 +45,41 @@ def test_pagination_pages(echo): results = [r for p in page_results for r in p.responses] assert results == [showcase.EchoResponse(content=i) for i in text.split(' ')] + + +@pytest.mark.asyncio +async def test_pagination_async(async_echo): + text = 'The hail in Wales falls mainly on the snails.' + results = [] + async for i in await async_echo.paged_expand({ + 'content': text, + 'page_size': 3, + }): + results.append(i) + + assert len(results) == 9 + assert results == [showcase.EchoResponse(content=i) + for i in text.split(' ')] + + +@pytest.mark.asyncio +async def test_pagination_pages_async(async_echo): + text = "The hail in Wales falls mainly on the snails." + page_results = [] + async for page in (await async_echo.paged_expand({ + 'content': text, + 'page_size': 3, + })).pages: + page_results.append(page) + + assert len(page_results) == 3 + assert not page_results[-1].next_page_token + + # The monolithic surface uses a wrapper type that needs an explicit property + # for a 'raw_page': we need to duplicate that interface, even though the + # architecture is different. + assert page_results[0].raw_page is page_results[0] + + results = [r for p in page_results for r in p.responses] + assert results == [showcase.EchoResponse(content=i) + for i in text.split(' ')] diff --git a/tests/system/test_resource_crud.py b/tests/system/test_resource_crud.py index 7d32c37e6f..5372da4b6b 100644 --- a/tests/system/test_resource_crud.py +++ b/tests/system/test_resource_crud.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + def test_crud_with_request(identity): count = len(identity.list_users().users) @@ -72,3 +74,45 @@ def test_path_parsing(messaging): messaging.blurb_path("bdfl", "apocalyptic", "city") ) assert expected == actual + + +@pytest.mark.asyncio +async def test_crud_with_request_async(async_identity): + pager = await async_identity.list_users() + count = len(pager.users) + user = await async_identity.create_user(request={'user': { + 'display_name': 'Guido van Rossum', + 'email': 'guido@guido.fake', + }}) + try: + assert user.display_name == 'Guido van Rossum' + assert user.email == 'guido@guido.fake' + pager = (await async_identity.list_users()) + assert len(pager.users) == count + 1 + assert (await async_identity.get_user({ + 'name': user.name + })).display_name == 'Guido van Rossum' + finally: + await async_identity.delete_user({'name': user.name}) + + +@pytest.mark.asyncio +async def test_crud_flattened_async(async_identity): + count = len((await async_identity.list_users()).users) + user = await async_identity.create_user( + display_name='Monty Python', + email='monty@python.org', + ) + try: + assert user.display_name == 'Monty Python' + assert user.email == 'monty@python.org' + assert len((await async_identity.list_users()).users) == count + 1 + assert (await async_identity.get_user(name=user.name)).display_name == 'Monty Python' + finally: + await async_identity.delete_user(name=user.name) + + +def test_path_methods_async(async_identity): + expected = "users/bdfl" + actual = async_identity.user_path("bdfl") + assert expected == actual diff --git a/tests/system/test_retry.py b/tests/system/test_retry.py index bf02842949..0bc70f9f8e 100644 --- a/tests/system/test_retry.py +++ b/tests/system/test_retry.py @@ -26,3 +26,14 @@ def test_retry_bubble(echo): 'message': 'This took longer than you said it should.', }, }) + + +@pytest.mark.asyncio +async def test_retry_bubble_async(async_echo): + with pytest.raises(exceptions.DeadlineExceeded): + await async_echo.echo({ + 'error': { + 'code': code_pb2.Code.Value('DEADLINE_EXCEEDED'), + 'message': 'This took longer than you said it should.', + }, + }) diff --git a/tests/unit/schema/test_api.py b/tests/unit/schema/test_api.py index b519b0353a..8dc1760cd8 100644 --- a/tests/unit/schema/test_api.py +++ b/tests/unit/schema/test_api.py @@ -364,6 +364,7 @@ def test_proto_names_import_collision_flattening(): module='squid', ), imp.Import(package=('google', 'api_core'), module='operation',), + imp.Import(package=('google', 'api_core'), module='operation_async',), } assert expected_imports == actual_imports diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index f1ec092a07..c0102402c2 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -83,6 +83,12 @@ def test_method_client_output_paged(): assert method.client_output.ident.name == 'ListFoosPager' +def test_method_client_output_async_empty(): + empty = make_message(name='Empty', package='google.protobuf') + method = make_method('Meh', output_message=empty) + assert method.client_output_async == wrappers.PrimitiveType.build(None) + + def test_method_paged_result_field_not_first(): paged = make_field(name='foos', message=make_message('Foo'), repeated=True) input_msg = make_message(name='ListFoosRequest', fields=( @@ -148,6 +154,7 @@ def test_method_paged_result_ref_types(): assert ref_type_names == { 'ListSquidsRequest', 'ListSquidsPager', + 'ListSquidsAsyncPager', 'Mollusc', } diff --git a/tests/unit/schema/wrappers/test_service.py b/tests/unit/schema/wrappers/test_service.py index e78a477b77..86d1aa2e97 100644 --- a/tests/unit/schema/wrappers/test_service.py +++ b/tests/unit/schema/wrappers/test_service.py @@ -36,8 +36,10 @@ def test_service_properties(): service = make_service(name='ThingDoer') assert service.name == 'ThingDoer' assert service.client_name == 'ThingDoerClient' + assert service.async_client_name == 'ThingDoerAsyncClient' assert service.transport_name == 'ThingDoerTransport' assert service.grpc_transport_name == 'ThingDoerGrpcTransport' + assert service.grpc_asyncio_transport_name == 'ThingDoerGrpcAsyncIOTransport' def test_service_host(): @@ -62,7 +64,7 @@ def test_service_names(): get_method('Jump', 'foo.bacon.JumpRequest', 'foo.bacon.JumpResponse'), get_method('Yawn', 'a.b.v1.c.YawnRequest', 'x.y.v1.z.YawnResponse'), )) - expected_names = {'ThingDoer', 'ThingDoerClient', + expected_names = {'ThingDoer', 'ThingDoerClient', 'ThingDoerAsyncClient', 'do_thing', 'jump', 'yawn'} assert service.names == expected_names @@ -73,7 +75,7 @@ def test_service_name_colliding_modules(): get_method('Jump', 'bacon.bar.JumpRequest', 'bacon.bar.JumpResponse'), get_method('Yawn', 'a.b.v1.c.YawnRequest', 'a.b.v1.c.YawnResponse'), )) - expected_names = {'ThingDoer', 'ThingDoerClient', + expected_names = {'ThingDoer', 'ThingDoerClient', 'ThingDoerAsyncClient', 'do_thing', 'jump', 'yawn', 'bar'} assert service.names == expected_names @@ -112,6 +114,7 @@ def test_service_python_modules_lro(): imp.Import(package=('foo',), module='baz'), imp.Import(package=('foo',), module='qux'), imp.Import(package=('google', 'api_core'), module='operation'), + imp.Import(package=('google', 'api_core'), module='operation_async'), } @@ -138,6 +141,7 @@ def test_service_python_modules_signature(): imp.Import(package=('foo',), module='baz'), imp.Import(package=('foo',), module='qux'), imp.Import(package=('google', 'api_core'), module='operation'), + imp.Import(package=('google', 'api_core'), module='operation_async'), }