Skip to content

Commit

Permalink
Create a base class for clients
Browse files Browse the repository at this point in the history
  • Loading branch information
lidizheng committed Apr 14, 2020
1 parent e8ff641 commit a51bcd3
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 310 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ from google.oauth2 import service_account # type: ignore
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport
from .base_client import {{ service.name }}BaseClient, {{ service.name }}BaseClientMeta


class {{ service.async_client_name }}Meta(type):
class {{ service.async_client_name }}Meta({{ service.name }}BaseClientMeta):
"""Metaclass for the {{ service.name }} client.

This provides class-level methods for building and retrieving
Expand Down Expand Up @@ -55,149 +55,13 @@ class {{ service.async_client_name }}Meta(type):
return next(iter(cls._transport_registry.values()))


class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}Meta):
class {{ service.async_client_name }}({{ service.name }}BaseClient, metaclass={{ service.async_client_name }}Meta):
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""

@staticmethod
def _get_default_mtls_endpoint(api_endpoint):
"""Convert api endpoint to mTLS endpoint.
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
Args:
api_endpoint (Optional[str]): the api endpoint to convert.
Returns:
str: converted mTLS api endpoint.
"""
if not api_endpoint:
return api_endpoint

mtls_endpoint_re = re.compile(
r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
)

m = mtls_endpoint_re.match(api_endpoint)
name, mtls, sandbox, googledomain = m.groups()
if mtls or not googledomain:
return api_endpoint

if sandbox:
return api_endpoint.replace(
"sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
)

return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")

DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
DEFAULT_MTLS_TRANSPORT = {{ service.grpc_asyncio_transport_name }}

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
file.

Args:
filename (str): The path to the service account private key json
file.
args: Additional arguments to pass to the constructor.
kwargs: Additional arguments to pass to the constructor.

Returns:
{@api.name}: The constructed client.
"""
credentials = service_account.Credentials.from_service_account_file(
filename)
kwargs['credentials'] = credentials
return cls(*args, **kwargs)

from_service_account_json = from_service_account_file


{% for message in service.resource_messages -%}
@staticmethod
def {{ message.resource_type|snake_case }}_path({% for arg in message.resource_path_args %}{{ arg }}: str,{% endfor %}) -> str:
"""Return a fully-qualified {{ message.resource_type|snake_case }} string."""
return "{{ message.resource_path }}".format({% for arg in message.resource_path_args %}{{ arg }}={{ arg }}, {% endfor %})

{% endfor %}

def __init__(self, *,
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = None,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Args:
credentials (Optional[google.auth.credentials.Credentials]): The
authorization credentials to attach to requests. These
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
transport (Union[str, ~.{{ service.name }}Transport]): The
transport to use. If set to None, a transport is chosen
automatically.
client_options (ClientOptions): Custom options for the client.
(1) The ``api_endpoint`` property can be used to override the
default endpoint provided by the client.
(2) If ``transport`` argument is None, ``client_options`` can be
used to create a mutual TLS transport. If ``client_cert_source``
is provided, mutual TLS transport will be created with the given
``api_endpoint`` or the default mTLS endpoint, and the client
SSL credentials obtained from ``client_cert_source``.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
"""
if isinstance(client_options, dict):
client_options = ClientOptions.from_dict(client_options)

# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
if isinstance(transport, {{ service.name }}Transport):
# transport is a {{ service.name }}Transport instance.
if credentials:
raise ValueError('When providing a transport instance, '
'provide its credentials directly.')
self._transport = transport
elif client_options is None or (
client_options.api_endpoint == None
and client_options.client_cert_source is None
):
# Don't trigger mTLS if we get an empty ClientOptions.
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials, host=self.DEFAULT_ENDPOINT
)
else:
# We have a non-empty ClientOptions. If client_cert_source is
# provided, trigger mTLS with user provided endpoint or the default
# mTLS endpoint.
if client_options.client_cert_source:
api_mtls_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_MTLS_ENDPOINT
)
else:
api_mtls_endpoint = None

api_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_ENDPOINT
)

self._transport = self.DEFAULT_MTLS_TRANSPORT(
credentials=credentials,
host=api_endpoint,
api_mtls_endpoint=api_mtls_endpoint,
client_cert_source=client_options.client_cert_source,
)
def _default_mtls_transport(cls) -> str:
"""Returns the default MTLS transport name."""
return "grpc_asyncio"

{% for method in service.methods.values() -%}
{% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self,
Expand Down Expand Up @@ -312,7 +176,7 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
client_info=self._client_info,
)
{%- if method.field_headers %}

Expand Down Expand Up @@ -367,16 +231,6 @@ class {{ service.async_client_name }}(metaclass={{ service.async_client_name }}M
{% endfor %}


try:
_client_info = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()


__all__ = (
'{{ service.async_client_name }}',
)
Expand Down
Loading

0 comments on commit a51bcd3

Please sign in to comment.