diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 6252747109..1dc50daf40 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -593,6 +593,63 @@ def client_output(self): # Return the usual output. return self.output + @utils.cached_property + def client_output_async(self): + """Return the output from the client layer. + + This takes into account transformations made by the outer GAPIC + client to transform the output from the transport. + + Returns: + Union[~.MessageType, ~.PythonType]: + A description of the return type. + """ + # Void messages ultimately return None. + if self.void: + return PrimitiveType.build(None) + + # If this method is an LRO, return a PythonType instance representing + # that. + if self.lro: + return PythonType(meta=metadata.Metadata( + address=metadata.Address( + name='AsyncOperation', + module='operation_async', + package=('google', 'api_core'), + collisions=self.lro.response_type.ident.collisions, + ), + documentation=utils.doc( + 'An object representing a long-running operation. \n\n' + 'The result type for the operation will be ' + ':class:`{ident}`: {doc}'.format( + doc=self.lro.response_type.meta.doc, + ident=self.lro.response_type.ident.sphinx, + ), + ), + )) + + # # If this method is paginated, return that method's pager class. + if self.paged_result_field: + return PythonType(meta=metadata.Metadata( + address=metadata.Address( + name=f'{self.name}AsyncPager', + package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + ( + 'services', + utils.to_snake_case(self.ident.parent[-1]), + ), + module='pagers', + collisions=self.input.ident.collisions, + ), + documentation=utils.doc( + f'{self.output.meta.doc}\n\n' + 'Iterating over this object will yield results and ' + 'resolve additional pages automatically.', + ), + )) + + # Return the usual output. + return self.output + @property def field_headers(self) -> Sequence[str]: """Return the field headers defined for this method.""" @@ -706,6 +763,8 @@ def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]: if not self.void: answer.append(self.client_output) answer.extend(self.client_output.field_types) + answer.append(self.client_output_async) + answer.extend(self.client_output_async.field_types) # If this method has LRO, it is possible (albeit unlikely) that # the LRO messages reside in a different module. 
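
Note on the wrappers.py hunks above: `client_output_async` mirrors `client_output`, but the LRO branch points at `google.api_core.operation_async.AsyncOperation` and the paged branch points at a new `<Method>AsyncPager`. A rough sketch of the annotations this produces in a generated async client follows; the Showcase method names (`wait`, `paged_expand`, `delete_user`) and the `services.echo.pagers` module path are assumptions borrowed from the system tests later in this diff, not output of the change itself.

from google.api_core import operation_async  # real module; provides AsyncOperation
from google.showcase_v1beta1.services.echo import pagers  # exists only after generation


class _AsyncClientAnnotationSketch:
    async def wait(self, request=None) -> operation_async.AsyncOperation:
        """LRO branch: the client-level return is an awaitable AsyncOperation."""

    async def paged_expand(self, request=None) -> pagers.PagedExpandAsyncPager:
        """Paged branch: the client-level return is the method's generated AsyncPager."""

    async def delete_user(self, request=None) -> None:
        """Void branch: the client-level return type is None."""
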
@@ -763,6 +822,11 @@ def client_name(self) -> str: """Returns the name of the generated client class""" return self.name + "Client" + @property + def async_client_name(self) -> str: + """Returns the name of the generated AsyncIO client class""" + return self.name + "AsyncClient" + @property def transport_name(self): return self.name + "Transport" @@ -771,6 +835,10 @@ def transport_name(self): def grpc_transport_name(self): return self.name + "GrpcTransport" + @property + def grpc_asyncio_transport_name(self): + return self.name + "GrpcAsyncIOTransport" + @property def has_lro(self) -> bool: """Return whether the service has a long-running method.""" diff --git a/gapic/templates/%namespace/%name/__init__.py.j2 b/gapic/templates/%namespace/%name/__init__.py.j2 index 1ea3128c57..0f94f844ca 100644 --- a/gapic/templates/%namespace/%name/__init__.py.j2 +++ b/gapic/templates/%namespace/%name/__init__.py.j2 @@ -35,6 +35,7 @@ _lazy_name_to_package_map = { 'types': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.types', {%- for service in api.services.values()|sort(attribute='name')|unique(attribute='name') if service.meta.address.subpackage == api.subpackage_view %} '{{ service.client_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client', + '{{ service.async_client_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client', '{{ service.transport_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.transports.base', '{{ service.grpc_transport_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.transports.grpc', {%- endfor %} {# Need to do types and enums #} @@ -105,6 +106,8 @@ from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.' if service.meta.address.subpackage == api.subpackage_view -%} from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%} {{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client import {{ service.client_name }} +from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%} + {{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client import {{ service.async_client_name }} {% endfor -%} {# Import messages and enums from each proto. 
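
For orientation, the `__init__.py.j2` hunk above renders, for an assumed Showcase `Echo` service, roughly as the snippet below; only the entries touched by this change are shown, and the concrete names are illustrative rather than actual generated output.

_lazy_name_to_package_map = {
    'echo_client': 'google.showcase_v1beta1.services.echo.client',
    # New in this change: the AsyncIO client becomes lazily importable as well.
    'echo_async_client': 'google.showcase_v1beta1.services.echo.async_client',
}

# ...and the direct import block later in the same template gains the async client:
from google.showcase_v1beta1.services.echo.async_client import EchoAsyncClient
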
@@ -141,6 +144,7 @@ __all__ = ( {% for service in api.services.values()|sort(attribute='name') if service.meta.address.subpackage == api.subpackage_view -%} '{{ service.client_name }}', +'{{ service.async_client_name }}', {% endfor -%} {% for proto in api.protos.values()|sort(attribute='module_name') if proto.meta.address.subpackage == api.subpackage_view -%} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 index f9f07d44df..c99b2a5f91 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 @@ -2,8 +2,10 @@ {% block content %} from .client import {{ service.client_name }} +from .async_client import {{ service.async_client_name }} __all__ = ( '{{ service.client_name }}', + '{{ service.async_client_name }}', ) {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 new file mode 100644 index 0000000000..f9887754df --- /dev/null +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 @@ -0,0 +1,309 @@ +{% extends '_base.py.j2' %} + +{% block content %} +from collections import OrderedDict +from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +{% filter sort_lines -%} +{% for method in service.methods.values() -%} +{% for ref_type in method.flat_ref_types -%} +{{ ref_type.ident.python_import }} +{% endfor -%} +{% endfor -%} +{% endfilter %} +from .transports.base import {{ service.name }}Transport +from .transports.grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport + + +class {{ service.async_client_name }}Meta(type): + """Metaclass for the {{ service.name }} client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry: Dict[str, Type[{{ service.name }}Transport]] = OrderedDict() + _transport_registry['grpc_asyncio'] = {{ service.name }}GrpcAsyncIOTransport + + def get_transport_class(cls, + label: str = None, + ) -> Type[{{ service.name }}Transport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). 
+ return next(iter(cls._transport_registry.values())) + + +class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}Meta): + """{{ service.meta.doc|rst(width=72, indent=4) }}""" + + DEFAULT_OPTIONS = ClientOptions.ClientOptions({% if service.host %}api_endpoint='{{ service.host }}'{% endif %}) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + {@api.name}: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + + {% for message in service.resource_messages -%} + @staticmethod + def {{ message.resource_type|snake_case }}_path({% for arg in message.resource_path_args %}{{ arg }}: str,{% endfor %}) -> str: + """Return a fully-qualified {{ message.resource_type|snake_case }} string.""" + return "{{ message.resource_path }}".format({% for arg in message.resource_path_args %}{{ arg }}={{ arg }}, {% endfor %}) + + {% endfor %} + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, {{ service.name }}Transport] = None, + client_options: ClientOptions = DEFAULT_OPTIONS, + ) -> None: + """Instantiate the {{ (service.async_client_name|snake_case).replace('_', ' ') }}. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.{{ service.name }}Transport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. + """ + if isinstance(client_options, dict): + client_options = ClientOptions.from_dict(client_options) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. 
+ if isinstance(transport, {{ service.name }}Transport): + if credentials: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + host=client_options.api_endpoint{% if service.host %} or '{{ service.host }}'{% endif %}, + ) + + {% for method in service.methods.values() -%} + {% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self, + {%- if not method.client_streaming %} + request: {{ method.input.ident }} = None, + *, + {% for field in method.flattened_fields.values() -%} + {{ field.name }}: {{ field.ident }} = None, + {% endfor -%} + {%- else %} + requests: AsyncIterator[{{ method.input.ident }}] = None, + *, + {% endif -%} + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + {%- if not method.server_streaming %} + ) -> {{ method.client_output_async.ident }}: + {%- else %} + ) -> AsyncIterable[{{ method.client_output_async.ident }}]: + {%- endif %} + r"""{{ method.meta.doc|rst(width=72, indent=8) }} + + Args: + {%- if not method.client_streaming %} + request (:class:`{{ method.input.ident.sphinx }}`): + The request object.{{ ' ' -}} + {{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }} + {% for key, field in method.flattened_fields.items() -%} + {{ field.name }} (:class:`{{ field.ident.sphinx }}`): + {{ field.meta.doc|rst(width=72, indent=16, nl=False) }} + This corresponds to the ``{{ key }}`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + {% endfor -%} + {%- else %} + requests (AsyncIterator[`{{ method.input.ident.sphinx }}`]): + The request object AsyncIterator.{{ ' ' -}} + {{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }} + {%- endif %} + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + {%- if not method.void %} + + Returns: + {%- if not method.server_streaming %} + {{ method.client_output_async.ident.sphinx }}: + {%- else %} + AsyncIterable[{{ method.client_output_async.ident.sphinx }}]: + {%- endif %} + {{ method.client_output_async.meta.doc|rst(width=72, indent=16) }} + {%- endif %} + """ + {%- if not method.client_streaming %} + # Create or coerce a protobuf request object. + {% if method.flattened_fields -%} + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + {% endif -%} + {% if method.input.ident.package != method.ident.package -%} {# request lives in a different package, so there is no proto wrapper #} + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. 
+ if isinstance(request, dict): + request = {{ method.input.ident }}(**request) + {% if method.flattened_fields -%}{# Cross-package req and flattened fields #} + elif not request: + request = {{ method.input.ident }}() + {% endif -%}{# Cross-package req and flattened fields #} + {%- else %} + request = {{ method.input.ident }}(request) + {% endif %} {# different request package #} + + {#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #} + {% if method.flattened_fields -%} + # If we have keyword arguments corresponding to fields on the + # request, apply these. + {% endif -%} + {%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %} + if {{ field.name }} is not None: + request.{{ key }} = {{ field.name }} + {%- endfor %} + {# They can be _extended_, however -#} + {%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %} + if {{ field.name }}: + request.{{ key }}.extend({{ field.name }}) + {%- endfor %} + {%- endif %} + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._transport.{{ method.name|snake_case }}, + {%- if method.retry %} + default_retry=retries.Retry( + {% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %} + {% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %} + {% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %} + predicate=retries.if_exception_type( + {%- filter sort_lines %} + {%- for ex in method.retry.retryable_exceptions %} + exceptions.{{ ex.__name__ }}, + {%- endfor %} + {%- endfilter %} + ), + ), + {%- endif %} + default_timeout={{ method.timeout }}, + client_info=_client_info, + ) + {%- if method.field_headers %} + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + {%- for field_header in method.field_headers %} + ('{{ field_header }}', request.{{ field_header }}), + {%- endfor %} + )), + ) + {%- endif %} + + # Send the request. + {% if not method.void %}response = {% endif %} + {%- if not method.server_streaming %}await {% endif %}rpc( + {%- if not method.client_streaming %} + request, + {%- else %} + requests, + {%- endif %} + retry=retry, + timeout=timeout, + metadata=metadata, + ) + {%- if method.lro %} + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._transport.operations_client, + {{ method.lro.response_type.ident }}, + metadata_type={{ method.lro.metadata_type.ident }}, + ) + {%- elif method.paged_result_field %} + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = {{ method.client_output_async.ident }}( + method=rpc, + request=request, + response=response, + ) + {%- endif %} + {%- if not method.void %} + + # Done; return the response. 
+ return response + {%- endif %} + {{ '\n' }} + {% endfor %} + + +try: + _client_info = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + '{{ api.naming.warehouse_package_name }}', + ).version, + ) +except pkg_resources.DistributionNotFound: + _client_info = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + '{{ service.async_client_name }}', +) +{% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 index 0e7ef018a7..5c069b68fd 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 @@ -6,7 +6,7 @@ {# This lives within the loop in order to ensure that this template is empty if there are no paged methods. -#} -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable {% filter sort_lines -%} {% for method in service.methods.values() | selectattr('paged_result_field') -%} @@ -71,5 +71,64 @@ class {{ method.name }}Pager: def __repr__(self) -> str: return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + +class {{ method.name }}AsyncPager: + """A pager for iterating through ``{{ method.name|snake_case }}`` requests. + + This class thinly wraps an initial + :class:`{{ method.output.ident.sphinx }}` object, and + provides an ``__aiter__`` method to iterate through its + ``{{ method.paged_result_field.name }}`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``{{ method.name }}`` requests and continue to iterate + through the ``{{ method.paged_result_field.name }}`` field on the + corresponding responses. + + All the usual :class:`{{ method.output.ident.sphinx }}` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[[{{ method.input.ident }}], + Awaitable[{{ method.output.ident }}]], + request: {{ method.input.ident }}, + response: {{ method.output.ident }}): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`{{ method.input.ident.sphinx }}`): + The initial request object. + response (:class:`{{ method.output.ident.sphinx }}`): + The initial response object. 
+ """ + self._method = method + self._request = {{ method.input.ident }}(request) + self._response = response + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[{{ method.output.ident }}]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request) + yield self._response + + def __aiter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'AsyncIterable') }}: + async def async_generator(): + async for page in self.pages: + for response in page.{{ method.paged_result_field.name }}: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + {% endfor %} {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 index 470cde5d19..fa97f46164 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 @@ -6,15 +6,18 @@ from typing import Dict, Type from .base import {{ service.name }}Transport from .grpc import {{ service.name }}GrpcTransport +from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[{{ service.name }}Transport]] _transport_registry['grpc'] = {{ service.name }}GrpcTransport +_transport_registry['grpc_asyncio'] = {{ service.name }}GrpcAsyncIOTransport __all__ = ( '{{ service.name }}Transport', '{{ service.name }}GrpcTransport', + '{{ service.name }}GrpcAsyncIOTransport', ) {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 index 694e0a1664..e710a7086e 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2 @@ -66,7 +66,10 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta): @property def {{ method.name|snake_case }}(self) -> typing.Callable[ [{{ method.input.ident }}], - {{ method.output.ident }}]: + typing.Union[ + {{ method.output.ident }}, + typing.Awaitable[{{ method.output.ident }}] + ]]: raise NotImplementedError {%- endfor %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 new file mode 100644 index 0000000000..57e2429924 --- /dev/null +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 @@ -0,0 +1,167 @@ +{% extends '_base.py.j2' %} + +{% block content %} +from typing import Awaitable, Callable, Dict + +from google.api_core import grpc_helpers_async # type: ignore +{%- if service.has_lro %} +from google.api_core import operations_v1 # type: ignore +{%- endif %} +from google.auth import credentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +{% filter sort_lines -%} +{% for method in service.methods.values() -%} +{{ 
method.input.ident.python_import }} +{{ method.output.ident.python_import }} +{% endfor -%} +{% endfilter %} +from .base import {{ service.name }}Transport + + +class {{ service.name }}GrpcAsyncIOTransport({{ service.name }}Transport): + """gRPC AsyncIO backend transport for {{ service.name }}. + + {{ service.meta.doc|rst(width=72, indent=4) }} + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + _stubs: Dict[str, Callable] + _grpc_channel: aio.Channel + + def __init__(self, *, + host: str{% if service.host %} = '{{ service.host }}'{% endif %}, + credentials: credentials.Credentials = None, + channel: aio.Channel = None) -> None: + """Instantiate the transport. + + Args: + host ({% if service.host %}Optional[str]{% else %}str{% endif %}): + {{- ' ' }}The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + """ + # Sanity check: Ensure that channel and credentials are not both + # provided. + if channel: + credentials = False + + # Run the base constructor. + super().__init__(host=host, credentials=credentials) + self._stubs = {} # type: Dict[str, Callable] + + # If a channel was explicitly provided, set it. + if channel: + self._grpc_channel = channel + + @classmethod + def create_channel(cls, + host: str{% if service.host %} = '{{ service.host }}'{% endif %}, + credentials: credentials.Credentials = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optionsl[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + scopes=cls.AUTH_SCOPES, + enable_asyncio=True, + **kwargs + ) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, '_grpc_channel'): + self._grpc_channel = self.create_channel( + self._host, + credentials=self._credentials, + ) + + # Return the channel from cache. + return self._grpc_channel + {%- if service.has_lro %} + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. 
+ if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + {%- endif %} + {%- for method in service.methods.values() %} + + @property + def {{ method.name|snake_case }}(self) -> Callable[ + [{{ method.input.ident }}], + Awaitable[{{ method.output.ident }}]]: + r"""Return a callable for the {{- ' ' -}} + {{ (method.name|snake_case).replace('_',' ')|wrap( + width=70, offset=40, indent=8) }} + {{- ' ' -}} method over gRPC. + + {{ method.meta.doc|rst(width=72, indent=8) }} + + Returns: + Callable[[~.{{ method.input.name }}], + ~.{{ method.output.name }}]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if '{{ method.name|snake_case }}' not in self._stubs: + self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}( + '/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}', + request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %}, + response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %}, + ) + return self._stubs['{{ method.name|snake_case }}'] + {%- endfor %} + + +__all__ = ( + '{{ service.name }}GrpcAsyncIOTransport', +) +{%- endblock -%} diff --git a/gapic/templates/mypy.ini.j2 b/gapic/templates/mypy.ini.j2 index f23e6b533a..4505b48543 100644 --- a/gapic/templates/mypy.ini.j2 +++ b/gapic/templates/mypy.ini.j2 @@ -1,3 +1,3 @@ [mypy] -python_version = 3.5 +python_version = 3.6 namespace_packages = True diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 5aa782c94c..bd758de24f 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -14,11 +14,27 @@ import collections import pytest +import asyncio -from google.showcase import EchoClient -from google.showcase import IdentityClient +from google.showcase import EchoClient, EchoAsyncClient +from google.showcase import IdentityClient, IdentityAsyncClient import grpc +from grpc.experimental import aio + +import logging +logging.basicConfig(level=logging.DEBUG) + + +# NOTE(lidiz) We must override the default event_loop fixture from +# pytest-asyncio. pytest fixture frees resources once there isn't any reference +# to it. So, the event loop might close before tests finishes. In the +# customized version, we don't close the event loop. 
+@pytest.fixture +def event_loop(): + loop = asyncio.get_event_loop() + loop.set_debug(True) + return loop @pytest.fixture @@ -29,6 +45,15 @@ def echo(): return EchoClient(transport=transport) +@pytest.fixture +def async_echo(event_loop): + event_loop.set_debug(True) + transport = EchoAsyncClient.get_transport_class('grpc_asyncio')( + channel=aio.insecure_channel('localhost:7469'), + ) + return EchoAsyncClient(transport=transport) + + @pytest.fixture def identity(): transport = IdentityClient.get_transport_class('grpc')( @@ -37,6 +62,14 @@ def identity(): return IdentityClient(transport=transport) +@pytest.fixture +async def async_identity(): + transport = IdentityAsyncClient.get_transport_class('grpc_asyncio')( + channel=aio.insecure_channel('localhost:7469'), + ) + return IdentityAsyncClient(transport=transport) + + class MetadataClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, diff --git a/tests/system/test_grpc_lro.py b/tests/system/test_grpc_lro.py index 617163a0db..a4578a168a 100644 --- a/tests/system/test_grpc_lro.py +++ b/tests/system/test_grpc_lro.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest from datetime import datetime, timedelta, timezone from google import showcase_v1beta1 @@ -27,3 +28,16 @@ def test_lro(echo): response = future.result() assert isinstance(response, showcase_v1beta1.WaitResponse) assert response.content.endswith('the snails...eventually.') + + +@pytest.mark.asyncio +async def test_lro_async(async_echo): + future = await async_echo.wait({ + 'end_time': datetime.now(tz=timezone.utc) + timedelta(seconds=1), + 'success': { + 'content': 'The hail in Wales falls mainly on the snails...eventually.' + }} + ) + response = await future.result() + assert isinstance(response, showcase_v1beta1.WaitResponse) + assert response.content.endswith('the snails...eventually.') diff --git a/tests/system/test_grpc_streams.py b/tests/system/test_grpc_streams.py index d0879d6a89..fb74b3b978 100644 --- a/tests/system/test_grpc_streams.py +++ b/tests/system/test_grpc_streams.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import pytest +import asyncio +import threading from google import showcase @@ -71,3 +75,142 @@ def test_stream_stream_passing_dict(echo): assert contents == ['hello', 'world!'] assert responses.trailing_metadata() == metadata + + +@pytest.mark.asyncio +async def test_async_unary_stream_reader(async_echo): + logging.debug('test_async_unary_stream_reader %s %s', threading.current_thread(), asyncio.get_event_loop()) + content = 'The hail in Wales falls mainly on the snails.' + call = await async_echo.expand({ + 'content': content, + }, metadata=metadata) + logging.debug('test_async_unary_stream_reader 2') + + # Consume the response and ensure it matches what we expect. + # with pytest.raises(exceptions.NotFound) as exc: + for ground_truth in content.split(' '): + logging.debug('test_async_unary_stream_reader 3') + response = await call.read() + assert response.content == ground_truth + logging.debug('test_async_unary_stream_reader 4') + assert ground_truth == 'snails.' 
+ + logging.debug('test_async_unary_stream_reader 5') + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_unary_stream_async_generator(async_echo): + logging.debug('test_async_unary_stream_async_generator') + content = 'The hail in Wales falls mainly on the snails.' + call = await async_echo.expand({ + 'content': content, + }, metadata=metadata) + + # Consume the response and ensure it matches what we expect. + # with pytest.raises(exceptions.NotFound) as exc: + tokens = iter(content.split(' ')) + async for response in call: + ground_truth = next(tokens) + assert response.content == ground_truth + assert ground_truth == 'snails.' + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_stream_unary_iterable(async_echo): + logging.debug('test_async_stream_unary_iterable') + requests = [] + requests.append(showcase.EchoRequest(content="hello")) + requests.append(showcase.EchoRequest(content="world!")) + + call = await async_echo.collect(requests) + response = await call + assert response.content == 'hello world!' + + +@pytest.mark.asyncio +async def test_async_stream_unary_async_generator(async_echo): + logging.debug('test_async_stream_unary_async_generator') + async def async_generator(): + yield showcase.EchoRequest(content="hello") + yield showcase.EchoRequest(content="world!") + + call = await async_echo.collect(async_generator()) + response = await call + assert response.content == 'hello world!' + + +@pytest.mark.asyncio +async def test_async_stream_unary_writer(async_echo): + logging.debug('test_async_stream_unary_writer') + call = await async_echo.collect() + await call.write(showcase.EchoRequest(content="hello")) + await call.write(showcase.EchoRequest(content="world!")) + await call.done_writing() + + response = await call + assert response.content == 'hello world!' + + +@pytest.mark.asyncio +async def test_async_stream_unary_passing_dict(async_echo): + logging.debug('test_async_stream_unary_passing_dict') + requests = [{'content': 'hello'}, {'content': 'world!'}] + call = await async_echo.collect(iter(requests)) + response = await call + assert response.content == 'hello world!' 
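
The streaming tests above all run through the pytest-asyncio fixtures from conftest.py; as a reference point, here is a minimal standalone sketch of the client-streaming flow they exercise, assuming a Showcase server is listening on localhost:7469 (the same assumption the fixtures make). `run_until_complete` is used instead of `asyncio.run` to stay within the Python 3.6 floor set in mypy.ini.

import asyncio

from grpc.experimental import aio
from google import showcase
from google.showcase import EchoAsyncClient


async def collect_demo():
    # Same construction the async_echo fixture uses.
    transport = EchoAsyncClient.get_transport_class('grpc_asyncio')(
        channel=aio.insecure_channel('localhost:7469'),
    )
    client = EchoAsyncClient(transport=transport)

    # Client-streaming: write requests explicitly, then await the single response.
    call = await client.collect()
    await call.write(showcase.EchoRequest(content='hello'))
    await call.write(showcase.EchoRequest(content='world!'))
    await call.done_writing()
    response = await call
    print(response.content)  # 'hello world!'


asyncio.get_event_loop().run_until_complete(collect_demo())
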
+ + +@pytest.mark.asyncio +async def test_async_stream_stream_reader_writier(async_echo): + logging.debug('test_async_stream_stream_reader_writier') + call = await async_echo.chat(metadata=metadata) + await call.write(showcase.EchoRequest(content="hello")) + await call.write(showcase.EchoRequest(content="world!")) + await call.done_writing() + + contents = [ + (await call.read()).content, + (await call.read()).content + ] + assert contents == ['hello', 'world!'] + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_stream_stream_async_generator(async_echo): + logging.debug('test_async_stream_stream_async_generator') + async def async_generator(): + yield showcase.EchoRequest(content="hello") + yield showcase.EchoRequest(content="world!") + + call = await async_echo.chat(async_generator(), metadata=metadata) + + contents = [] + async for response in call: + contents.append(response.content) + assert contents == ['hello', 'world!'] + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata + + +@pytest.mark.asyncio +async def test_async_stream_stream_passing_dict(async_echo): + logging.debug('test_async_stream_stream_passing_dict') + requests = [{'content': 'hello'}, {'content': 'world!'}] + call = await async_echo.chat(iter(requests), metadata=metadata) + + contents = [] + async for response in call: + contents.append(response.content) + assert contents == ['hello', 'world!'] + + trailing_metadata = await call.trailing_metadata() + assert trailing_metadata == metadata diff --git a/tests/system/test_grpc_unary.py b/tests/system/test_grpc_unary.py index f8735a3a31..d0c647d832 100644 --- a/tests/system/test_grpc_unary.py +++ b/tests/system/test_grpc_unary.py @@ -13,6 +13,8 @@ # limitations under the License. import pytest +import asyncio +import logging from google.api_core import exceptions from google.rpc import code_pb2 @@ -45,3 +47,33 @@ def test_unary_error(echo): }) assert exc.value.code == 400 assert exc.value.message == message + + +@pytest.mark.asyncio +async def test_async_unary_with_request_object(async_echo): + response = await async_echo.echo(showcase.EchoRequest( + content='The hail in Wales falls mainly on the snails.', + ), timeout=1) + assert response.content == 'The hail in Wales falls mainly on the snails.' + + +@pytest.mark.asyncio +async def test_async_unary_with_dict(async_echo): + response = await async_echo.echo({ + 'content': 'The hail in Wales falls mainly on the snails.', + }) + assert response.content == 'The hail in Wales falls mainly on the snails.' + + +@pytest.mark.asyncio +async def test_async_unary_error(async_echo): + message = 'Bad things! Bad things!' + with pytest.raises(exceptions.InvalidArgument) as exc: + await async_echo.echo({ + 'error': { + 'code': code_pb2.Code.Value('INVALID_ARGUMENT'), + 'message': message, + }, + }) + assert exc.value.code == 400 + assert exc.value.message == message diff --git a/tests/system/test_pagination.py b/tests/system/test_pagination.py index 781614cad4..8f53a6c01d 100644 --- a/tests/system/test_pagination.py +++ b/tests/system/test_pagination.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest from google import showcase @@ -44,3 +45,41 @@ def test_pagination_pages(echo): results = [r for p in page_results for r in p.responses] assert results == [showcase.EchoResponse(content=i) for i in text.split(' ')] + + +@pytest.mark.asyncio +async def test_pagination_async(async_echo): + text = 'The hail in Wales falls mainly on the snails.' + results = [] + async for i in await async_echo.paged_expand({ + 'content': text, + 'page_size': 3, + }): + results.append(i) + + assert len(results) == 9 + assert results == [showcase.EchoResponse(content=i) + for i in text.split(' ')] + + +@pytest.mark.asyncio +async def test_pagination_pages_async(async_echo): + text = "The hail in Wales falls mainly on the snails." + page_results = [] + async for page in (await async_echo.paged_expand({ + 'content': text, + 'page_size': 3, + })).pages: + page_results.append(page) + + assert len(page_results) == 3 + assert not page_results[-1].next_page_token + + # The monolithic surface uses a wrapper type that needs an explicit property + # for a 'raw_page': we need to duplicate that interface, even though the + # architecture is different. + assert page_results[0].raw_page is page_results[0] + + results = [r for p in page_results for r in p.responses] + assert results == [showcase.EchoResponse(content=i) + for i in text.split(' ')] diff --git a/tests/system/test_resource_crud.py b/tests/system/test_resource_crud.py index 597b936cca..ae790deba7 100644 --- a/tests/system/test_resource_crud.py +++ b/tests/system/test_resource_crud.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + def test_crud_with_request(identity): count = len(identity.list_users().users) @@ -50,3 +52,45 @@ def test_path_methods(identity): actual = identity.user_path("bdfl") assert expected == actual + + +@pytest.mark.asyncio +async def test_crud_with_request_async(async_identity): + pager = await async_identity.list_users() + count = len(pager.users) + user = await async_identity.create_user(request={'user': { + 'display_name': 'Guido van Rossum', + 'email': 'guido@guido.fake', + }}) + try: + assert user.display_name == 'Guido van Rossum' + assert user.email == 'guido@guido.fake' + pager = (await async_identity.list_users()) + assert len(pager.users) == count + 1 + assert (await async_identity.get_user({ + 'name': user.name + })).display_name == 'Guido van Rossum' + finally: + await async_identity.delete_user({'name': user.name}) + + +@pytest.mark.asyncio +async def test_crud_flattened_async(async_identity): + count = len((await async_identity.list_users()).users) + user = await async_identity.create_user( + display_name='Monty Python', + email='monty@python.org', + ) + try: + assert user.display_name == 'Monty Python' + assert user.email == 'monty@python.org' + assert len((await async_identity.list_users()).users) == count + 1 + assert (await async_identity.get_user(name=user.name)).display_name == 'Monty Python' + finally: + await async_identity.delete_user(name=user.name) + + +def test_path_methods_async(async_identity): + expected = "users/bdfl" + actual = async_identity.user_path("bdfl") + assert expected == actual diff --git a/tests/system/test_retry.py b/tests/system/test_retry.py index bf02842949..0bc70f9f8e 100644 --- a/tests/system/test_retry.py +++ b/tests/system/test_retry.py @@ -26,3 +26,14 @@ def test_retry_bubble(echo): 'message': 'This took longer than you said it should.', }, }) + + +@pytest.mark.asyncio +async def 
test_retry_bubble_async(async_echo): + with pytest.raises(exceptions.DeadlineExceeded): + await async_echo.echo({ + 'error': { + 'code': code_pb2.Code.Value('DEADLINE_EXCEEDED'), + 'message': 'This took longer than you said it should.', + }, + })
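
Taken together, the system tests in this diff exercise the whole async surface. As a closing reference, a short end-to-end sketch of that surface (unary call, pagination via `__aiter__`, and an LRO), again assuming a Showcase server on localhost:7469; the method names and payloads follow the tests above.

import asyncio
from datetime import datetime, timedelta, timezone

from grpc.experimental import aio
from google.showcase import EchoAsyncClient


async def demo():
    transport = EchoAsyncClient.get_transport_class('grpc_asyncio')(
        channel=aio.insecure_channel('localhost:7469'),
    )
    client = EchoAsyncClient(transport=transport)

    # Unary: a plain awaitable, with an optional timeout.
    response = await client.echo({'content': 'hello'}, timeout=1)
    print(response.content)

    # Pagination: the returned AsyncPager resolves further pages transparently.
    pager = await client.paged_expand({'content': 'one two three four', 'page_size': 2})
    async for item in pager:
        print(item.content)

    # LRO: returns an AsyncOperation whose result() is awaited.
    operation = await client.wait({
        'end_time': datetime.now(tz=timezone.utc) + timedelta(seconds=1),
        'success': {'content': 'done'},
    })
    print((await operation.result()).content)


asyncio.get_event_loop().run_until_complete(demo())
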