Skip to content

Commit

Permalink
Provide AsyncIO support for generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
lidizheng committed Apr 22, 2020
1 parent 920e419 commit 14b1760
Show file tree
Hide file tree
Showing 24 changed files with 1,406 additions and 55 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ jobs:
name: Install system dependencies.
command: |
apt-get update
apt-get install -y curl pandoc unzip
apt-get install -y curl pandoc unzip git
- run:
name: Install nox.
command: pip install nox
Expand Down Expand Up @@ -302,7 +302,7 @@ jobs:
name: Install system dependencies.
command: |
apt-get update
apt-get install -y curl pandoc unzip
apt-get install -y curl pandoc unzip git
- run:
name: Install nox.
command: pip install nox
Expand Down
26 changes: 22 additions & 4 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,13 @@ def __getattr__(self, name):

@utils.cached_property
def client_output(self):
return self._client_output(enable_asyncio=False)

@utils.cached_property
def client_output_async(self):
return self._client_output(enable_asyncio=True)

def _client_output(self, enable_asyncio: bool):
"""Return the output from the client layer.
This takes into account transformations made by the outer GAPIC
Expand All @@ -584,8 +591,8 @@ def client_output(self):
if self.lro:
return PythonType(meta=metadata.Metadata(
address=metadata.Address(
name='Operation',
module='operation',
name='AsyncOperation' if enable_asyncio else 'Operation',
module='operation_async' if enable_asyncio else 'operation',
package=('google', 'api_core'),
collisions=self.lro.response_type.ident.collisions,
),
Expand All @@ -603,7 +610,7 @@ def client_output(self):
if self.paged_result_field:
return PythonType(meta=metadata.Metadata(
address=metadata.Address(
name=f'{self.name}Pager',
name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager',
package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + (
'services',
utils.to_snake_case(self.ident.parent[-1]),
Expand Down Expand Up @@ -734,6 +741,8 @@ def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]:
if not self.void:
answer.append(self.client_output)
answer.extend(self.client_output.field_types)
answer.append(self.client_output_async)
answer.extend(self.client_output_async.field_types)

# If this method has LRO, it is possible (albeit unlikely) that
# the LRO messages reside in a different module.
Expand Down Expand Up @@ -791,6 +800,11 @@ def client_name(self) -> str:
"""Returns the name of the generated client class"""
return self.name + "Client"

@property
def async_client_name(self) -> str:
"""Returns the name of the generated AsyncIO client class"""
return self.name + "AsyncClient"

@property
def transport_name(self):
return self.name + "Transport"
Expand All @@ -799,6 +813,10 @@ def transport_name(self):
def grpc_transport_name(self):
return self.name + "GrpcTransport"

@property
def grpc_asyncio_transport_name(self):
return self.name + "GrpcAsyncIOTransport"

@property
def has_lro(self) -> bool:
"""Return whether the service has a long-running method."""
Expand Down Expand Up @@ -846,7 +864,7 @@ def names(self) -> FrozenSet[str]:
used for imports.
"""
# Put together a set of the service and method names.
answer = {self.name, self.client_name}
answer = {self.name, self.client_name, self.async_client_name}
answer.update(
utils.to_snake_case(i.name) for i in self.methods.values()
)
Expand Down
4 changes: 4 additions & 0 deletions gapic/templates/%namespace/%name/__init__.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ _lazy_name_to_package_map = {
'types': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.types',
{%- for service in api.services.values()|sort(attribute='name')|unique(attribute='name') if service.meta.address.subpackage == api.subpackage_view %}
'{{ service.client_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client',
'{{ service.async_client_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client',
'{{ service.transport_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.transports.base',
'{{ service.grpc_transport_name|snake_case }}': '{% if api.naming.module_namespace %}{{ api.naming.module_namespace|join(".") }}.{% endif -%}{{ api.naming.versioned_module_name }}.services.transports.grpc',
{%- endfor %} {# Need to do types and enums #}
Expand Down Expand Up @@ -105,6 +106,8 @@ from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.'
if service.meta.address.subpackage == api.subpackage_view -%}
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client import {{ service.client_name }}
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client import {{ service.async_client_name }}
{% endfor -%}

{# Import messages and enums from each proto.
Expand Down Expand Up @@ -141,6 +144,7 @@ __all__ = (
{% for service in api.services.values()|sort(attribute='name')
if service.meta.address.subpackage == api.subpackage_view -%}
'{{ service.client_name }}',
'{{ service.async_client_name }}',
{% endfor -%}
{% for proto in api.protos.values()|sort(attribute='module_name')
if proto.meta.address.subpackage == api.subpackage_view -%}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

{% block content %}
from .client import {{ service.client_name }}
from .async_client import {{ service.async_client_name }}

__all__ = (
'{{ service.client_name }}',
'{{ service.async_client_name }}',
)
{% endblock %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
{% extends '_base.py.j2' %}

{% block content %}
from collections import OrderedDict
import functools
import re
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials # type: ignore
from google.oauth2 import service_account # type: ignore

{% filter sort_lines -%}
{% for method in service.methods.values() -%}
{% for ref_type in method.flat_ref_types -%}
{{ ref_type.ident.python_import }}
{% endfor -%}
{% endfor -%}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
from .client import {{ service.client_name }}


class {{ service.async_client_name }}:
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""

_client: {{ service.client_name }}

DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT
DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT

{% for message in service.resource_messages -%}
{{ message.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.{{ message.resource_type|snake_case }}_path)

{% endfor %}

from_service_account_file = {{ service.client_name }}.from_service_account_file
from_service_account_json = from_service_account_file

get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }}))

def __init__(self, *,
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = "grpc_asyncio",
client_options: ClientOptions = None,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Args:
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
transport (Union[str, ~.{{ service.name }}Transport]): The
transport to use. If set to None, a transport is chosen
automatically.
client_options (ClientOptions): Custom options for the client.
(1) The ``api_endpoint`` property can be used to override the
default endpoint provided by the client.
(2) If ``transport`` argument is None, ``client_options`` can be
used to create a mutual TLS transport. If ``client_cert_source``
is provided, mutual TLS transport will be created with the given
``api_endpoint`` or the default mTLS endpoint, and the client
SSL credentials obtained from ``client_cert_source``.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
"""
{# NOTE(lidiz) Not using kwargs since we want the docstring and types. #}
self._client = {{ service.client_name }}(
credentials=credentials,
transport=transport,
client_options=client_options,
)

{% for method in service.methods.values() -%}
{% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self,
{%- if not method.client_streaming %}
request: {{ method.input.ident }} = None,
*,
{% for field in method.flattened_fields.values() -%}
{{ field.name }}: {{ field.ident }} = None,
{% endfor -%}
{%- else %}
requests: AsyncIterator[{{ method.input.ident }}] = None,
*,
{% endif -%}
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
{%- if not method.server_streaming %}
) -> {{ method.client_output_async.ident }}:
{%- else %}
) -> AsyncIterable[{{ method.client_output_async.ident }}]:
{%- endif %}
r"""{{ method.meta.doc|rst(width=72, indent=8) }}

Args:
{%- if not method.client_streaming %}
request (:class:`{{ method.input.ident.sphinx }}`):
The request object.{{ ' ' -}}
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
{% for key, field in method.flattened_fields.items() -%}
{{ field.name }} (:class:`{{ field.ident.sphinx }}`):
{{ field.meta.doc|rst(width=72, indent=16, nl=False) }}
This corresponds to the ``{{ key }}`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
{% endfor -%}
{%- else %}
requests (AsyncIterator[`{{ method.input.ident.sphinx }}`]):
The request object AsyncIterator.{{ ' ' -}}
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
{%- endif %}
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
{%- if not method.void %}

Returns:
{%- if not method.server_streaming %}
{{ method.client_output_async.ident.sphinx }}:
{%- else %}
AsyncIterable[{{ method.client_output_async.ident.sphinx }}]:
{%- endif %}
{{ method.client_output_async.meta.doc|rst(width=72, indent=16) }}
{%- endif %}
"""
{%- if not method.client_streaming %}
# Create or coerce a protobuf request object.
{% if method.flattened_fields -%}
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]):
raise ValueError('If the `request` argument is set, then none of '
'the individual field arguments should be set.')

{% endif -%}
{% if method.input.ident.package != method.ident.package -%} {# request lives in a different package, so there is no proto wrapper #}
# The request isn't a proto-plus wrapped type,
# so it must be constructed via keyword expansion.
if isinstance(request, dict):
request = {{ method.input.ident }}(**request)
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
elif not request:
request = {{ method.input.ident }}()
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.{{ method.name|snake_case }},
{%- if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
predicate=retries.if_exception_type(
{%- filter sort_lines %}
{%- for ex in method.retry.retryable_exceptions %}
exceptions.{{ ex.__name__ }},
{%- endfor %}
{%- endfilter %}
),
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
)
{%- if method.field_headers %}

# Certain fields should be provided within the metadata header;
# add these here.
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata((
{%- for field_header in method.field_headers %}
('{{ field_header }}', request.{{ field_header }}),
{%- endfor %}
)),
)
{%- endif %}

# Send the request.
{% if not method.void %}response = {% endif %}
{%- if not method.server_streaming %}await {% endif %}rpc(
{%- if not method.client_streaming %}
request,
{%- else %}
requests,
{%- endif %}
retry=retry,
timeout=timeout,
metadata=metadata,
)
{%- if method.lro %}

# Wrap the response in an operation future.
response = operation_async.from_gapic(
response,
self._client._transport.operations_client,
{{ method.lro.response_type.ident }},
metadata_type={{ method.lro.metadata_type.ident }},
)
{%- elif method.paged_result_field %}

# This method is paged; wrap the response in a pager, which provides
# an `__aiter__` convenience method.
response = {{ method.client_output_async.ident }}(
method=rpc,
request=request,
response=response,
)
{%- endif %}
{%- if not method.void %}

# Done; return the response.
return response
{%- endif %}
{{ '\n' }}
{% endfor %}


try:
_client_info = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()


__all__ = (
'{{ service.async_client_name }}',
)
{% endblock %}
Loading

0 comments on commit 14b1760

Please sign in to comment.