Skip to content

Commit

Permalink
Remove the parent class for grpc transports
Browse files Browse the repository at this point in the history
  • Loading branch information
lidizheng committed Apr 17, 2020
1 parent 3cc730c commit e176020
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class {{ service.name }}Transport(abc.ABC):
self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: credentials.Credentials = None,
**kwargs,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -59,7 +60,7 @@ class {{ service.name }}Transport(abc.ABC):
@property
def operations_client(self) -> operations_v1.OperationsClient:
"""Return the client designed to process long-running operations."""
raise NotImplementedError
raise NotImplementedError()
{%- endif %}
{%- for method in service.methods.values() %}

Expand All @@ -70,7 +71,7 @@ class {{ service.name }}Transport(abc.ABC):
{{ method.output.ident }},
typing.Awaitable[{{ method.output.ident }}]
]]:
raise NotImplementedError
raise NotImplementedError()
{%- endfor %}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import grpc # type: ignore
{{ method.output.ident.python_import }}
{% endfor -%}
{% endfilter %}
from .grpc_base import {{ service.name }}GrpcBaseTransport
from .base import {{ service.name }}Transport


class {{ service.name }}GrpcTransport({{ service.name }}GrpcBaseTransport[grpc.Channel]):
class {{ service.name }}GrpcTransport({{ service.name }}Transport):
"""gRPC backend transport for {{ service.name }}.

{{ service.meta.doc|rst(width=72, indent=4) }}
Expand All @@ -34,6 +34,71 @@ class {{ service.name }}GrpcTransport({{ service.name }}GrpcBaseTransport[grpc.C
It sends protocol buffers over the wire using gRPC (which is built on
top of HTTP/2); the ``grpcio`` package must be installed.
"""
_stubs: Dict[str, Callable]

def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: credentials.Credentials = None,
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
"""Instantiate the transport.

Args:
host ({% if service.host %}Optional[str]{% else %}str{% endif %}):
{{- ' ' }}The hostname to connect to.
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
This argument is ignored if ``channel`` is provided.
channel (Optional[grpc.Channel]): A ``Channel`` instance through
which to make calls.
api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If
provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
``client_cert_source`` or applicatin default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
"""
if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
elif api_mtls_endpoint:
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
cert, key = client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
ssl_credentials = SslCredentials().ssl_credentials

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
ssl_credentials=ssl_credentials,
scopes=self.AUTH_SCOPES,
)

# Run the base constructor.
super().__init__(host=host, credentials=credentials)
self._stubs = {} # type: Dict[str, Callable]

@classmethod
def create_channel(cls,
Expand Down Expand Up @@ -66,6 +131,23 @@ class {{ service.name }}GrpcTransport({{ service.name }}GrpcBaseTransport[grpc.C
**kwargs
)

@property
def grpc_channel(self) -> grpc.Channel:
"""Create the channel designed to connect to this service.

This property caches on the instance; repeated calls return
the same channel.
"""
# Sanity check: Only create a new channel if we do not already
# have one.
if not hasattr(self, '_grpc_channel'):
self._grpc_channel = self.create_channel(
self._host,
credentials=self._credentials,
)

# Return the channel from cache.
return self._grpc_channel
{%- if service.has_lro %}

@property
Expand Down Expand Up @@ -103,7 +185,17 @@ class {{ service.name }}GrpcTransport({{ service.name }}GrpcBaseTransport[grpc.C
A function that, when called, will call the underlying RPC
on the server.
"""
return super().{{ method.name|snake_case }}
# Generate a "stub function" on-the-fly which will actually make
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
{%- endfor %}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
{% block content %}
from typing import Awaitable, Callable, Dict, Sequence, Tuple

from google.api_core import grpc_helpers_async # type: ignore
from google.api_core import grpc_helpers_async # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
from google.api_core import operations_v1 # type: ignore
{%- endif %}
from google.auth import credentials # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore

import grpc # type: ignore
from grpc.experimental import aio # type: ignore

{% filter sort_lines -%}
Expand All @@ -17,10 +19,11 @@ from grpc.experimental import aio # type: ignore
{{ method.output.ident.python_import }}
{% endfor -%}
{% endfilter %}
from .grpc_base import {{ service.name }}GrpcBaseTransport
from .base import {{ service.name }}Transport
from .grpc import {{ service.name }}GrpcTransport


class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransport[aio.Channel]):
class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
"""gRPC AsyncIO backend transport for {{ service.name }}.

{{ service.meta.doc|rst(width=72, indent=4) }}
Expand All @@ -33,6 +36,9 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransp
top of HTTP/2); the ``grpcio`` package must be installed.
"""

_grpc_channel: aio.Channel
_stubs: Dict[str, Callable] = {}

@classmethod
def create_channel(cls,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
Expand Down Expand Up @@ -64,6 +70,87 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransp
**kwargs
)

def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: credentials.Credentials = None,
channel: aio.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
"""Instantiate the transport.

Args:
host ({% if service.host %}Optional[str]{% else %}str{% endif %}):
{{- ' ' }}The hostname to connect to.
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
This argument is ignored if ``channel`` is provided.
channel (Optional[aio.Channel]): A ``Channel`` instance through
which to make calls.
api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If
provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
``client_cert_source`` or applicatin default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
"""
if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
elif api_mtls_endpoint:
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
cert, key = client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
ssl_credentials = SslCredentials().ssl_credentials

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
ssl_credentials=ssl_credentials,
scopes=self.AUTH_SCOPES,
)

# Run the base constructor.
super().__init__(host=host, credentials=credentials)
self._stubs = {}

@property
def grpc_channel(self) -> aio.Channel:
"""Create the channel designed to connect to this service.

This property caches on the instance; repeated calls return
the same channel.
"""
# Sanity check: Only create a new channel if we do not already
# have one.
if not hasattr(self, '_grpc_channel'):
self._grpc_channel = self.create_channel(
self._host,
credentials=self._credentials,
)

# Return the channel from cache.
return self._grpc_channel
{%- if service.has_lro %}

@property
Expand Down Expand Up @@ -101,7 +188,17 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}GrpcBaseTransp
A function that, when called, will call the underlying RPC
on the server.
"""
return super().{{ method.name|snake_case }}
# Generate a "stub function" on-the-fly which will actually make
# the request.
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.name|snake_case }}' not in self._stubs:
self._stubs['{{ method.name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
)
return self._stubs['{{ method.name|snake_case }}']
{%- endfor %}


Expand Down
Loading

0 comments on commit e176020

Please sign in to comment.