Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Provide AsyncIO support for generated code #365

Merged
merged 4 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,13 @@ def __getattr__(self, name):

@utils.cached_property
def client_output(self):
return self._client_output(enable_asyncio=False)

@utils.cached_property
def client_output_async(self):
return self._client_output(enable_asyncio=True)

def _client_output(self, enable_asyncio: bool):
"""Return the output from the client layer.

This takes into account transformations made by the outer GAPIC
Expand All @@ -584,8 +591,8 @@ def client_output(self):
if self.lro:
return PythonType(meta=metadata.Metadata(
address=metadata.Address(
name='Operation',
module='operation',
name='AsyncOperation' if enable_asyncio else 'Operation',
module='operation_async' if enable_asyncio else 'operation',
package=('google', 'api_core'),
collisions=self.lro.response_type.ident.collisions,
),
Expand All @@ -603,7 +610,7 @@ def client_output(self):
if self.paged_result_field:
return PythonType(meta=metadata.Metadata(
address=metadata.Address(
name=f'{self.name}Pager',
name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager',
package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + (
'services',
utils.to_snake_case(self.ident.parent[-1]),
Expand Down Expand Up @@ -744,6 +751,8 @@ def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]:
if not self.void:
answer.append(self.client_output)
answer.extend(self.client_output.field_types)
answer.append(self.client_output_async)
answer.extend(self.client_output_async.field_types)

# If this method has LRO, it is possible (albeit unlikely) that
# the LRO messages reside in a different module.
Expand Down Expand Up @@ -801,6 +810,11 @@ def client_name(self) -> str:
"""Returns the name of the generated client class"""
return self.name + "Client"

@property
def async_client_name(self) -> str:
"""Returns the name of the generated AsyncIO client class"""
return self.name + "AsyncClient"

@property
def transport_name(self):
return self.name + "Transport"
Expand All @@ -809,6 +823,10 @@ def transport_name(self):
def grpc_transport_name(self):
return self.name + "GrpcTransport"

@property
def grpc_asyncio_transport_name(self):
return self.name + "GrpcAsyncIOTransport"

@property
def has_lro(self) -> bool:
"""Return whether the service has a long-running method."""
Expand Down Expand Up @@ -856,7 +874,7 @@ def names(self) -> FrozenSet[str]:
used for imports.
"""
# Put together a set of the service and method names.
answer = {self.name, self.client_name}
answer = {self.name, self.client_name, self.async_client_name}
answer.update(
utils.to_snake_case(i.name) for i in self.methods.values()
)
Expand Down
3 changes: 3 additions & 0 deletions gapic/templates/%namespace/%name/__init__.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.'
if service.meta.address.subpackage == api.subpackage_view -%}
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client import {{ service.client_name }}
from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%}
{{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client import {{ service.async_client_name }}
{% endfor -%}

{# Import messages and enums from each proto.
Expand Down Expand Up @@ -48,6 +50,7 @@ __all__ = (
{% for service in api.services.values()|sort(attribute='name')
if service.meta.address.subpackage == api.subpackage_view -%}
'{{ service.client_name }}',
'{{ service.async_client_name }}',
{% endfor -%}
{% for proto in api.protos.values()|sort(attribute='module_name')
if proto.meta.address.subpackage == api.subpackage_view -%}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

{% block content %}
from .client import {{ service.client_name }}
from .async_client import {{ service.async_client_name }}

__all__ = (
'{{ service.client_name }}',
'{{ service.async_client_name }}',
)
{% endblock %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
{% extends '_base.py.j2' %}

{% block content %}
from collections import OrderedDict
import functools
import re
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials # type: ignore
from google.oauth2 import service_account # type: ignore

{% filter sort_lines -%}
{% for method in service.methods.values() -%}
{% for ref_type in method.flat_ref_types -%}
{{ ref_type.ident.python_import }}
{% endfor -%}
{% endfor -%}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
from .client import {{ service.client_name }}


class {{ service.async_client_name }}:
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""

_client: {{ service.client_name }}

DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT
DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT

{% for message in service.resource_messages -%}
{{ message.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.{{ message.resource_type|snake_case }}_path)

{% endfor %}

from_service_account_file = {{ service.client_name }}.from_service_account_file
from_service_account_json = from_service_account_file

get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }}))

def __init__(self, *,
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = 'grpc_asyncio',
client_options: ClientOptions = None,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Args:
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
transport (Union[str, ~.{{ service.name }}Transport]): The
transport to use. If set to None, a transport is chosen
automatically.
client_options (ClientOptions): Custom options for the client. It
won't take effect if a ``transport`` instance is provided.
(1) The ``api_endpoint`` property can be used to override the
default endpoint provided by the client. GOOGLE_API_USE_MTLS
environment variable can also be used to override the endpoint:
"always" (always use the default mTLS endpoint), "never" (always
use the default regular endpoint, this is the default value for
the environment variable) and "auto" (auto switch to the default
mTLS endpoint if client SSL credentials is present). However,
the ``api_endpoint`` property takes precedence if provided.
(2) The ``client_cert_source`` property is used to provide client
SSL credentials for mutual TLS transport. If not provided, the
default SSL credentials will be used if present.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
"""
{# NOTE(lidiz) Not using kwargs since we want the docstring and types. #}
self._client = {{ service.client_name }}(
credentials=credentials,
transport=transport,
client_options=client_options,
)

{% for method in service.methods.values() -%}
{% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self,
{%- if not method.client_streaming %}
request: {{ method.input.ident }} = None,
*,
{% for field in method.flattened_fields.values() -%}
{{ field.name }}: {{ field.ident }} = None,
{% endfor -%}
{%- else %}
requests: AsyncIterator[{{ method.input.ident }}] = None,
*,
{% endif -%}
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
{%- if not method.server_streaming %}
) -> {{ method.client_output_async.ident }}:
{%- else %}
) -> AsyncIterable[{{ method.client_output_async.ident }}]:
{%- endif %}
r"""{{ method.meta.doc|rst(width=72, indent=8) }}

Args:
{%- if not method.client_streaming %}
request (:class:`{{ method.input.ident.sphinx }}`):
The request object.{{ ' ' -}}
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
{% for key, field in method.flattened_fields.items() -%}
{{ field.name }} (:class:`{{ field.ident.sphinx }}`):
{{ field.meta.doc|rst(width=72, indent=16, nl=False) }}
This corresponds to the ``{{ key }}`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
{% endfor -%}
{%- else %}
requests (AsyncIterator[`{{ method.input.ident.sphinx }}`]):
The request object AsyncIterator.{{ ' ' -}}
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }}
{%- endif %}
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
{%- if not method.void %}

Returns:
{%- if not method.server_streaming %}
{{ method.client_output_async.ident.sphinx }}:
{%- else %}
AsyncIterable[{{ method.client_output_async.ident.sphinx }}]:
{%- endif %}
{{ method.client_output_async.meta.doc|rst(width=72, indent=16) }}
{%- endif %}
"""
{%- if not method.client_streaming %}
# Create or coerce a protobuf request object.
{% if method.flattened_fields -%}
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]):
raise ValueError('If the `request` argument is set, then none of '
'the individual field arguments should be set.')

{% endif -%}
{% if method.input.ident.package != method.ident.package -%} {# request lives in a different package, so there is no proto wrapper #}
# The request isn't a proto-plus wrapped type,
# so it must be constructed via keyword expansion.
if isinstance(request, dict):
request = {{ method.input.ident }}(**request)
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
elif not request:
request = {{ method.input.ident }}()
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.{{ method.name|snake_case }},
{%- if method.retry %}
default_retry=retries.Retry(
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
predicate=retries.if_exception_type(
{%- filter sort_lines %}
{%- for ex in method.retry.retryable_exceptions %}
exceptions.{{ ex.__name__ }},
{%- endfor %}
{%- endfilter %}
),
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
)
{%- if method.field_headers %}

# Certain fields should be provided within the metadata header;
# add these here.
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata((
{%- for field_header in method.field_headers %}
{%- if not method.client_streaming %}
('{{ field_header }}', request.{{ field_header }}),
{%- endif %}
{%- endfor %}
)),
)
{%- endif %}

# Send the request.
{% if not method.void %}response = {% endif %}
{%- if not method.server_streaming %}await {% endif %}rpc(
{%- if not method.client_streaming %}
request,
{%- else %}
requests,
{%- endif %}
retry=retry,
timeout=timeout,
metadata=metadata,
)
{%- if method.lro %}

# Wrap the response in an operation future.
response = operation_async.from_gapic(
response,
self._client._transport.operations_client,
{{ method.lro.response_type.ident }},
metadata_type={{ method.lro.metadata_type.ident }},
)
{%- elif method.paged_result_field %}

# This method is paged; wrap the response in a pager, which provides
# an `__aiter__` convenience method.
response = {{ method.client_output_async.ident }}(
method=rpc,
request=request,
response=response,
)
{%- endif %}
{%- if not method.void %}

# Done; return the response.
return response
{%- endif %}
{{ '\n' }}
{% endfor %}


try:
_client_info = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()


__all__ = (
'{{ service.async_client_name }}',
)
{% endblock %}
Loading