From c4d79cfd2c349f3b79d78bef97de8f010e74526b Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Thu, 2 Apr 2020 23:36:38 -0700 Subject: [PATCH 1/9] Add mTLS feature --- .../%sub/services/%service/client.py.j2 | 76 ++++- .../services/%service/transports/grpc.py.j2 | 54 +++- gapic/templates/setup.py.j2 | 1 + .../%name_%version/%sub/test_%service.py.j2 | 261 ++++++++++++++---- 4 files changed, 323 insertions(+), 69 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 1f821b8af9..fb9a383f11 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -2,7 +2,7 @@ {% block content %} from collections import OrderedDict -from typing import Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union +from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union import pkg_resources import google.api_core.client_options as ClientOptions # type: ignore @@ -23,6 +23,37 @@ from .transports.base import {{ service.name }}Transport from .transports.grpc import {{ service.name }}GrpcTransport +def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if ( + api_endpoint is None + or api_endpoint.find("mtls.sandbox.googleapis.com") != -1 + or api_endpoint.find("mtls.googleapis.com") != -1 + or api_endpoint.find(".googleapis.com") == -1 + ): + # If the endpoint is already mTLS or the endpoint is not a googleapi, + # there is no need to convert. + return api_endpoint + + if api_endpoint.find(".sandbox.googleapis.com") != -1: + return api_endpoint.replace( + ".sandbox.googleapis.com", ".mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + +_DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else }None{% endif %} +_DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint(_DEFAULT_ENDPOINT) + + class {{ service.client_name }}Meta(type): """Metaclass for the {{ service.name }} client. @@ -57,7 +88,7 @@ class {{ service.client_name }}Meta(type): class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): """{{ service.meta.doc|rst(width=72, indent=4) }}""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions({% if service.host %}api_endpoint='{{ service.host }}'{% endif %}) + DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=_DEFAULT_ENDPOINT) @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): @@ -106,23 +137,60 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): transport to use. If set to None, a transport is chosen automatically. client_options (ClientOptions): Custom options for the client. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. + (2) If ``transport`` argument is None, ``client_options`` can be + used to create a mutual TLS transport. If ``api_endpoint`` is + provided and different from the default endpoint, or the + ``client_cert_source`` property is provided, mutual TLS + transport will be created if client SSL credentials are found. + Client SSL credentials are obtained from ``client_cert_source`` + or application default SSL credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): client_options = ClientOptions.from_dict(client_options) + # Set default api endpoint if not set. + if client_options.api_endpoint is None: + client_options.api_endpoint = _DEFAULT_ENDPOINT + # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, {{ service.name }}Transport): + # transport is a {{ service.name }}Transport instance. if credentials: raise ValueError('When providing a transport instance, ' 'provide its credentials directly.') self._transport = transport + elif transport is not None: + # transport is a string representing the name. + self._transport = Transport( + credentials=credentials, host=client_options.api_endpoint + ) else: - Transport = type(self).get_transport_class(transport) + # transport is None, we will create a transport instance. + # Figure out if mTLS channel should be created. + api_mtls_endpoint = None + + # If the user overrides endpoint, use it as the mTLS endpoint. If the + # user doesn't override endpoint, but provides client_cert_source, + # use the default mTLS endpoint. + if client_options.api_endpoint != _DEFAULT_ENDPOINT: + api_mtls_endpoint = client_options.api_endpoint + elif client_options.client_cert_source: + api_mtls_endpoint = _DEFAULT_MTLS_ENDPOINT + + Transport = type(self).get_transport_class() self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint{% if service.host %} or '{{ service.host }}'{% endif %}, + host=client_options.api_endpoint, + api_mtls_endpoint=api_mtls_endpoint, + client_cert_source=client_options.client_cert_source, ) {% for method in service.methods.values() -%} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index bd3b074e74..8e87d77d72 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -1,13 +1,15 @@ {% extends '_base.py.j2' %} {% block content %} -from typing import Callable, Dict +from typing import Callable, Dict, Tuple from google.api_core import grpc_helpers # type: ignore {%- if service.has_lro %} from google.api_core import operations_v1 # type: ignore {%- endif %} from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + import grpc # type: ignore @@ -35,7 +37,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): def __init__(self, *, host: str{% if service.host %} = '{{ service.host }}'{% endif %}, credentials: credentials.Credentials = None, - channel: grpc.Channel = None) -> None: + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_callback: Callable[[], Tuple[bytes, bytes]] = None) -> None: """Instantiate the transport. Args: @@ -49,19 +53,55 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): This argument is ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. + api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If + provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A + callback to provide client SSL certificate bytes and private key + bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` + is None. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. """ - # Sanity check: Ensure that channel and credentials are not both - # provided. if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + host = ( + (":" in api_mtls_endpoint) + and api_mtls_endpoint + or (api_mtls_endpoint + ":443") + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = grpc_helpers.create_channel( + host, + credentials=credentials, + ssl_credentials=ssl_credentials, + scopes=self.AUTH_SCOPES, + ) + # Run the base constructor. super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - # If a channel was explicitly provided, set it. - if channel: - self._grpc_channel = channel @classmethod def create_channel(cls, diff --git a/gapic/templates/setup.py.j2 b/gapic/templates/setup.py.j2 index 9d408a23dd..91bff5f82b 100644 --- a/gapic/templates/setup.py.j2 +++ b/gapic/templates/setup.py.j2 @@ -16,6 +16,7 @@ setuptools.setup( platforms='Posix; MacOS X; Windows', include_package_data=True, install_requires=( + 'google-auth >= 1.13.1', 'google-api-core >= 1.8.0, < 2.0.0dev', 'googleapis-common-protos >= 1.5.8', 'grpcio >= 1.10.0', diff --git a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 index 3916e795d9..62e4737c8d 100644 --- a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 @@ -14,7 +14,11 @@ from google.auth import credentials from google.oauth2 import service_account from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }} from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports +from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _get_default_mtls_endpoint +from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _DEFAULT_ENDPOINT +from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _DEFAULT_MTLS_ENDPOINT from google.api_core import client_options +from google.api_core import grpc_helpers {% if service.has_lro -%} from google.api_core import future from google.api_core import operations_v1 @@ -30,6 +34,25 @@ from google.longrunning import operations_pb2 {% endfilter %} +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert _get_default_mtls_endpoint(None) == None + assert _get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert _get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert _get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert _get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert _get_default_mtls_endpoint(non_googleapi) == non_googleapi + + def test_{{ service.client_name|snake_case }}_from_service_account_file(): creds = credentials.AnonymousCredentials() with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: @@ -45,17 +68,54 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file(): def test_{{ service.client_name|snake_case }}_client_options(): # Check the default options have their expected values. - {% if service.host %}assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == '{{ service.host }}'{% endif %} + assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %} + assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == _DEFAULT_ENDPOINT + + # Check that the given channel is used. + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: + transport = transports.{{ service.name }}GrpcTransport( + credentials=credentials.AnonymousCredentials() + ) + client = {{ service.client_name }}(transport=transport) + gtc.assert_not_called() + + # Check mTLS is not triggered with empty client options. + options = client_options.ClientOptions() + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: + transport = gtc.return_value = mock.MagicMock() + client = {{ service.client_name }}(client_options=options) + transport.assert_called_once_with( + api_mtls_endpoint=None, + client_cert_source=None, + credentials=None, + host=_DEFAULT_ENDPOINT, + ) - # Check that options can be customized. + # Check mTLS is triggered with api endpoint override. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: transport = gtc.return_value = mock.MagicMock() - client = {{ service.client_name }}( - client_options=options + client = {{ service.client_name }}(client_options=options) + transport.assert_called_once_with( + api_mtls_endpoint="squid.clam.whelk", + client_cert_source=None, + credentials=None, + host="squid.clam.whelk", ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + # Check mTLS is triggered if client_cert_source is provided. + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: + transport = gtc.return_value = mock.MagicMock() + client = {{ service.client_name }}(client_options=options) + transport.assert_called_once_with( + api_mtls_endpoint=_DEFAULT_MTLS_ENDPOINT, + client_cert_source=client_cert_source_callback, + credentials=None, + host=_DEFAULT_ENDPOINT, + ) def test_{{ service.client_name|snake_case }}_client_options_from_dict(): with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: @@ -63,7 +123,12 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict(): client = {{ service.client_name }}( client_options={'api_endpoint': 'squid.clam.whelk'} ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + transport.assert_called_once_with( + api_mtls_endpoint="squid.clam.whelk", + client_cert_source=None, + credentials=None, + host="squid.clam.whelk", + ) {% for method in service.methods.values() -%} @@ -154,7 +219,7 @@ def test_{{ method.name|snake_case }}_field_headers(): '__call__') as call: call.return_value = {{ method.output.ident }}() client.{{ method.name|snake_case }}(request) - + # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] @@ -200,57 +265,57 @@ def test_{{ method.name|snake_case }}_from_dict(): {% endif %} -{% if method.flattened_fields %} -def test_{{ method.name|snake_case }}_flattened(): - client = {{ service.client_name }}( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.{{ method.name|snake_case }}), - '__call__') as call: - # Designate an appropriate return value for the call. - {% if method.void -%} - call.return_value = None - {% elif method.lro -%} - call.return_value = operations_pb2.Operation(name='operations/op') - {% elif method.server_streaming -%} - call.return_value = iter([{{ method.output.ident }}()]) - {% else -%} - call.return_value = {{ method.output.ident }}() - {% endif %} - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.{{ method.name|snake_case }}( - {%- for field in method.flattened_fields.values() %} - {{ field.name }}={{ field.mock_value }}, - {%- endfor %} - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] +{% if method.flattened_fields %} +def test_{{ method.name|snake_case }}_flattened(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.{{ method.name|snake_case }}), + '__call__') as call: + # Designate an appropriate return value for the call. + {% if method.void -%} + call.return_value = None + {% elif method.lro -%} + call.return_value = operations_pb2.Operation(name='operations/op') + {% elif method.server_streaming -%} + call.return_value = iter([{{ method.output.ident }}()]) + {% else -%} + call.return_value = {{ method.output.ident }}() + {% endif %} + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.{{ method.name|snake_case }}( + {%- for field in method.flattened_fields.values() %} + {{ field.name }}={{ field.mock_value }}, + {%- endfor %} + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] {% for key, field in method.flattened_fields.items() -%} - assert args[0].{{ key }} == {{ field.mock_value }} - {% endfor %} - - -def test_{{ method.name|snake_case }}_flattened_error(): - client = {{ service.client_name }}( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.{{ method.name|snake_case }}( - {{ method.input.ident }}(), - {%- for field in method.flattened_fields.values() %} - {{ field.name }}={{ field.mock_value }}, - {%- endfor %} - ) + assert args[0].{{ key }} == {{ field.mock_value }} + {% endfor %} + + +def test_{{ method.name|snake_case }}_flattened_error(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.{{ method.name|snake_case }}( + {{ method.input.ident }}(), + {%- for field in method.flattened_fields.values() %} + {{ field.name }}={{ field.mock_value }}, + {%- endfor %} + ) {% endif %} @@ -509,10 +574,90 @@ def test_{{ service.name|snake_case }}_host_with_port(): def test_{{ service.name|snake_case }}_grpc_transport_channel(): channel = grpc.insecure_channel('http://localhost/') + + # Check that if channel is provided, mtls endpoint and client_cert_source + # won't be used. + callback = mock.MagicMock() transport = transports.{{ service.name }}GrpcTransport( + host="squid.clam.whelk", channel=channel, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=callback, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert not callback.called + + +@mock.patch("grpc.ssl_channel_credentials", autospec=True) +@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) +def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_client_cert_source( + grpc_create_channel, grpc_ssl_channel_cred +): + # Check that if channel is None, but api_mtls_endpoint and client_cert_source + # are provided, then a mTLS channel will be created. + mock_cred = mock.Mock() + + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + transport = transports.{{ service.name }}GrpcTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, ) - assert transport.grpc_channel is channel + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=( + {%- for scope in service.oauth_scopes %} + '{{ scope }}', + {%- endfor %} + ), + ) + assert transport.grpc_channel == mock_grpc_channel + + +@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) +def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(grpc_create_channel): + # Check that if channel and client_cert_source are None, but api_mtls_endpoint + # is provided, then a mTLS channel will be created with SSL ADC. + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + # Mock google.auth.transport.grpc.SslCredentials class. + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + mock_cred = mock.Mock() + transport = transports.{{ service.name }}GrpcTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=( + {%- for scope in service.oauth_scopes %} + '{{ scope }}', + {%- endfor %} + ), + ) + assert transport.grpc_channel == mock_grpc_channel {% if service.has_lro -%} From 23cf0be548854e0ad4004ea08ca5c815a9cecd73 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Fri, 3 Apr 2020 00:22:09 -0700 Subject: [PATCH 2/9] fix --- .../%name_%version/%sub/services/%service/client.py.j2 | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index fb9a383f11..0b76f91d28 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -50,7 +50,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") -_DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else }None{% endif %} +_DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %} _DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint(_DEFAULT_ENDPOINT) @@ -146,7 +146,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): transport will be created if client SSL credentials are found. Client SSL credentials are obtained from ``client_cert_source`` or application default SSL credentials. - + Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport creation failed for any reason. @@ -169,6 +169,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): self._transport = transport elif transport is not None: # transport is a string representing the name. + Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, host=client_options.api_endpoint ) From b01dc1e7144ef62f80b4bb30b288501d71e71415 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Fri, 3 Apr 2020 01:16:03 -0700 Subject: [PATCH 3/9] fix grpc client_cert_source name error --- .../%name_%version/%sub/services/%service/transports/grpc.py.j2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index 8e87d77d72..39862b2dc2 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -39,7 +39,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): credentials: credentials.Credentials = None, channel: grpc.Channel = None, api_mtls_endpoint: str = None, - client_cert_callback: Callable[[], Tuple[bytes, bytes]] = None) -> None: + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None: """Instantiate the transport. Args: From fcff3a4fd1ab764ffc08ae8fb4124ab46ad90646 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Fri, 3 Apr 2020 01:51:42 -0700 Subject: [PATCH 4/9] fix transport --- .../%name_%version/%sub/services/%service/client.py.j2 | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 0b76f91d28..6cbc173bcb 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -174,7 +174,9 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): credentials=credentials, host=client_options.api_endpoint ) else: - # transport is None, we will create a transport instance. + # transport is None, we will figure out if mTLS channel should be + # created, and create a proper {{ service.name }}GrpcTransport instance. + # Figure out if mTLS channel should be created. api_mtls_endpoint = None @@ -186,8 +188,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): elif client_options.client_cert_source: api_mtls_endpoint = _DEFAULT_MTLS_ENDPOINT - Transport = type(self).get_transport_class() - self._transport = Transport( + self._transport = {{ service.name }}GrpcTransport( credentials=credentials, host=client_options.api_endpoint, api_mtls_endpoint=api_mtls_endpoint, From 1fc99aeea41b4ac938652d2982f1e89e360b1c13 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Fri, 3 Apr 2020 14:24:35 -0700 Subject: [PATCH 5/9] fix test problems --- .../%sub/services/%service/client.py.j2 | 27 +++++++++---------- .../%name_%version/%sub/test_%service.py.j2 | 22 +++++++-------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 6cbc173bcb..ff3989bb51 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -167,26 +167,23 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): raise ValueError('When providing a transport instance, ' 'provide its credentials directly.') self._transport = transport - elif transport is not None: - # transport is a string representing the name. + elif transport is not None or ( + client_options.api_endpoint == _DEFAULT_ENDPOINT + and client_options.client_cert_source is None + ): + # Don't trigger mTLS. Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, host=client_options.api_endpoint ) else: - # transport is None, we will figure out if mTLS channel should be - # created, and create a proper {{ service.name }}GrpcTransport instance. - - # Figure out if mTLS channel should be created. - api_mtls_endpoint = None - - # If the user overrides endpoint, use it as the mTLS endpoint. If the - # user doesn't override endpoint, but provides client_cert_source, - # use the default mTLS endpoint. - if client_options.api_endpoint != _DEFAULT_ENDPOINT: - api_mtls_endpoint = client_options.api_endpoint - elif client_options.client_cert_source: - api_mtls_endpoint = _DEFAULT_MTLS_ENDPOINT + # Trigger mTLS. If the user overrides endpoint, use it as the mTLS + # endpoint, otherwise use the default mTLS endpoint. + api_mtls_endpoint = ( + (client_options.api_endpoint != _DEFAULT_ENDPOINT) + and client_options.api_endpoint + or _DEFAULT_MTLS_ENDPOINT + ) self._transport = {{ service.name }}GrpcTransport( credentials=credentials, diff --git a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 index 62e4737c8d..0b66602390 100644 --- a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 @@ -71,7 +71,7 @@ def test_{{ service.client_name|snake_case }}_client_options(): assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %} assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == _DEFAULT_ENDPOINT - # Check that the given channel is used. + # Check that if channel is provided we won't create a new one. with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: transport = transports.{{ service.name }}GrpcTransport( credentials=credentials.AnonymousCredentials() @@ -85,18 +85,16 @@ def test_{{ service.client_name|snake_case }}_client_options(): transport = gtc.return_value = mock.MagicMock() client = {{ service.client_name }}(client_options=options) transport.assert_called_once_with( - api_mtls_endpoint=None, - client_cert_source=None, credentials=None, host=_DEFAULT_ENDPOINT, ) # Check mTLS is triggered with api endpoint override. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: - transport = gtc.return_value = mock.MagicMock() + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) - transport.assert_called_once_with( + grpc_transport.assert_called_once_with( api_mtls_endpoint="squid.clam.whelk", client_cert_source=None, credentials=None, @@ -107,10 +105,10 @@ def test_{{ service.client_name|snake_case }}_client_options(): options = client_options.ClientOptions( client_cert_source=client_cert_source_callback ) - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: - transport = gtc.return_value = mock.MagicMock() + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) - transport.assert_called_once_with( + grpc_transport.assert_called_once_with( api_mtls_endpoint=_DEFAULT_MTLS_ENDPOINT, client_cert_source=client_cert_source_callback, credentials=None, @@ -118,12 +116,12 @@ def test_{{ service.client_name|snake_case }}_client_options(): ) def test_{{ service.client_name|snake_case }}_client_options_from_dict(): - with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: - transport = gtc.return_value = mock.MagicMock() + with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None client = {{ service.client_name }}( client_options={'api_endpoint': 'squid.clam.whelk'} ) - transport.assert_called_once_with( + grpc_transport.assert_called_once_with( api_mtls_endpoint="squid.clam.whelk", client_cert_source=None, credentials=None, From 969d6d6d939e3e44543d6cd78b1ef07f5ce21b87 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Fri, 3 Apr 2020 15:38:06 -0700 Subject: [PATCH 6/9] use unreleased python-api-core --- noxfile.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/noxfile.py b/noxfile.py index d3d84243b9..2733f4318e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -92,6 +92,10 @@ def showcase_unit(session): """Run the generated unit tests against the Showcase library.""" # Install pytest and gapic-generator-python + session.install( + "-e", + "git+https://github.com/googleapis/python-api-core.git@ca6c41cf460e505e6b228263170927270626222a#egg=google-api-core", + ) session.install('coverage', 'pytest', 'pytest-cov') session.install('.') From 2a8f4362f8fb77f22f5035b4cb55afb75c27d479 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Fri, 3 Apr 2020 16:20:24 -0700 Subject: [PATCH 7/9] update circleci yml --- .circleci/config.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index b6891e3670..1bb5158969 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -253,7 +253,7 @@ jobs: ln -s /usr/src/protoc/bin/protoc /usr/local/bin/protoc - run: name: Run showcase tests. - command: nox -s showcase_alternative_templates + command: nox -s showcase_alternative_templates showcase-unit-3.6: docker: - image: python:3.6-slim @@ -263,7 +263,7 @@ jobs: name: Install system dependencies. command: | apt-get update - apt-get install -y curl pandoc unzip + apt-get install -y curl pandoc unzip git - run: name: Install protoc 3.7.1. command: | @@ -287,7 +287,7 @@ jobs: name: Install system dependencies. command: | apt-get update - apt-get install -y curl pandoc unzip + apt-get install -y curl pandoc unzip git - run: name: Install protoc 3.7.1. command: | @@ -311,7 +311,7 @@ jobs: name: Install system dependencies. command: | apt-get update - apt-get install -y curl pandoc unzip + apt-get install -y curl pandoc unzip git - run: name: Install protoc 3.7.1. command: | @@ -335,7 +335,7 @@ jobs: name: Install system dependencies. command: | apt-get update - apt-get install -y curl pandoc unzip + apt-get install -y curl pandoc unzip git - run: name: Install protoc 3.7.1. command: | @@ -359,7 +359,7 @@ jobs: name: Install system dependencies. command: | apt-get update - apt-get install -y curl pandoc unzip + apt-get install -y curl pandoc unzip git - run: name: Install protoc 3.7.1. command: | @@ -383,7 +383,7 @@ jobs: name: Install system dependencies. command: | apt-get update - apt-get install -y curl pandoc unzip + apt-get install -y curl pandoc unzip git - run: name: Install protoc 3.7.1. command: | From cd67c8fa0c035151b528722085ea53e9a8c36f67 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Mon, 6 Apr 2020 23:40:12 -0700 Subject: [PATCH 8/9] update the code --- .../%sub/services/%service/client.py.j2 | 80 ++++++++++--------- .../services/%service/transports/grpc.py.j2 | 9 +-- .../%name_%version/%sub/test_%service.py.j2 | 32 ++++---- 3 files changed, 62 insertions(+), 59 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index ff3989bb51..c981814807 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -2,6 +2,7 @@ {% block content %} from collections import OrderedDict +import re from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union import pkg_resources @@ -23,37 +24,6 @@ from .transports.base import {{ service.name }}Transport from .transports.grpc import {{ service.name }}GrpcTransport -def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. - """ - if ( - api_endpoint is None - or api_endpoint.find("mtls.sandbox.googleapis.com") != -1 - or api_endpoint.find("mtls.googleapis.com") != -1 - or api_endpoint.find(".googleapis.com") == -1 - ): - # If the endpoint is already mTLS or the endpoint is not a googleapi, - # there is no need to convert. - return api_endpoint - - if api_endpoint.find(".sandbox.googleapis.com") != -1: - return api_endpoint.replace( - ".sandbox.googleapis.com", ".mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - -_DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %} -_DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint(_DEFAULT_ENDPOINT) - - class {{ service.client_name }}Meta(type): """Metaclass for the {{ service.name }} client. @@ -88,7 +58,40 @@ class {{ service.client_name }}Meta(type): class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): """{{ service.meta.doc|rst(width=72, indent=4) }}""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=_DEFAULT_ENDPOINT) + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %} + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=DEFAULT_ENDPOINT) @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): @@ -156,7 +159,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): # Set default api endpoint if not set. if client_options.api_endpoint is None: - client_options.api_endpoint = _DEFAULT_ENDPOINT + client_options.api_endpoint = {{ service.client_name }}.DEFAULT_ENDPOINT # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -168,7 +171,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): 'provide its credentials directly.') self._transport = transport elif transport is not None or ( - client_options.api_endpoint == _DEFAULT_ENDPOINT + client_options.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT and client_options.client_cert_source is None ): # Don't trigger mTLS. @@ -179,11 +182,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): else: # Trigger mTLS. If the user overrides endpoint, use it as the mTLS # endpoint, otherwise use the default mTLS endpoint. - api_mtls_endpoint = ( - (client_options.api_endpoint != _DEFAULT_ENDPOINT) - and client_options.api_endpoint - or _DEFAULT_MTLS_ENDPOINT - ) + if client_options.api_endpoint != {{ service.client_name }}.DEFAULT_ENDPOINT: + api_mtls_endpoint = client_options.api_endpoint + else: + api_mtls_endpoint = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT self._transport = {{ service.name }}GrpcTransport( credentials=credentials, diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index 39862b2dc2..61d0efbdf5 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -74,11 +74,10 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - host = ( - (":" in api_mtls_endpoint) - and api_mtls_endpoint - or (api_mtls_endpoint + ":443") - ) + if ":" in api_mtls_endpoint: + host = api_mtls_endpoint + else: + host = api_mtls_endpoint + ":443" # Create SSL credentials with client_cert_source or application # default SSL credentials. diff --git a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 index 0b66602390..1df314aabc 100644 --- a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 @@ -14,9 +14,6 @@ from google.auth import credentials from google.oauth2 import service_account from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }} from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports -from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _get_default_mtls_endpoint -from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _DEFAULT_ENDPOINT -from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.client import _DEFAULT_MTLS_ENDPOINT from google.api_core import client_options from google.api_core import grpc_helpers {% if service.has_lro -%} @@ -45,12 +42,12 @@ def test__get_default_mtls_endpoint(): sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" non_googleapi = "api.example.com" - assert _get_default_mtls_endpoint(None) == None - assert _get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert _get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert _get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert _get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert _get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert {{ service.client_name }}._get_default_mtls_endpoint(None) == None + assert {{ service.client_name }}._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert {{ service.client_name }}._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert {{ service.client_name }}._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert {{ service.client_name }}._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi def test_{{ service.client_name|snake_case }}_from_service_account_file(): @@ -69,7 +66,7 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file(): def test_{{ service.client_name|snake_case }}_client_options(): # Check the default options have their expected values. assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %} - assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == _DEFAULT_ENDPOINT + assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT # Check that if channel is provided we won't create a new one. with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc: @@ -86,7 +83,7 @@ def test_{{ service.client_name|snake_case }}_client_options(): client = {{ service.client_name }}(client_options=options) transport.assert_called_once_with( credentials=None, - host=_DEFAULT_ENDPOINT, + host={{ service.client_name }}.DEFAULT_ENDPOINT, ) # Check mTLS is triggered with api endpoint override. @@ -109,10 +106,10 @@ def test_{{ service.client_name|snake_case }}_client_options(): grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) grpc_transport.assert_called_once_with( - api_mtls_endpoint=_DEFAULT_MTLS_ENDPOINT, + api_mtls_endpoint={{ service.client_name }}.DEFAULT_MTLS_ENDPOINT, client_cert_source=client_cert_source_callback, credentials=None, - host=_DEFAULT_ENDPOINT, + host={{ service.client_name }}.DEFAULT_ENDPOINT, ) def test_{{ service.client_name|snake_case }}_client_options_from_dict(): @@ -624,8 +621,13 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_client_c assert transport.grpc_channel == mock_grpc_channel +@pytest.mark.parametrize( + "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] +) @mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(grpc_create_channel): +def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc( + grpc_create_channel, api_mtls_endpoint +): # Check that if channel and client_cert_source are None, but api_mtls_endpoint # is provided, then a mTLS channel will be created with SSL ADC. mock_grpc_channel = mock.Mock() @@ -642,7 +644,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(grpc transport = transports.{{ service.name }}GrpcTransport( host="squid.clam.whelk", credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", + api_mtls_endpoint=api_mtls_endpoint, client_cert_source=None, ) grpc_create_channel.assert_called_once_with( From b491018aff4a18cfa349c0d7ff7e03555f0759ca Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Tue, 7 Apr 2020 14:37:40 -0700 Subject: [PATCH 9/9] update the code --- .../%name_%version/%sub/services/%service/client.py.j2 | 10 ++++------ .../%sub/services/%service/transports/grpc.py.j2 | 5 +---- .../tests/unit/%name_%version/%sub/test_%service.py.j2 | 6 +++--- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index c981814807..f0b7d05381 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -159,7 +159,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): # Set default api endpoint if not set. if client_options.api_endpoint is None: - client_options.api_endpoint = {{ service.client_name }}.DEFAULT_ENDPOINT + client_options.api_endpoint = self.DEFAULT_ENDPOINT # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -171,7 +171,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): 'provide its credentials directly.') self._transport = transport elif transport is not None or ( - client_options.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT + client_options.api_endpoint == self.DEFAULT_ENDPOINT and client_options.client_cert_source is None ): # Don't trigger mTLS. @@ -182,10 +182,8 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): else: # Trigger mTLS. If the user overrides endpoint, use it as the mTLS # endpoint, otherwise use the default mTLS endpoint. - if client_options.api_endpoint != {{ service.client_name }}.DEFAULT_ENDPOINT: - api_mtls_endpoint = client_options.api_endpoint - else: - api_mtls_endpoint = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT + option_endpoint = client_options.api_endpoint + api_mtls_endpoint = self.DEFAULT_MTLS_ENDPOINT if option_endpoint == self.DEFAULT_ENDPOINT else option_endpoint self._transport = {{ service.name }}GrpcTransport( credentials=credentials, diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index 61d0efbdf5..eb47dbdc52 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -74,10 +74,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - if ":" in api_mtls_endpoint: - host = api_mtls_endpoint - else: - host = api_mtls_endpoint + ":443" + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" # Create SSL credentials with client_cert_source or application # default SSL credentials. diff --git a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 index 1df314aabc..e55cc99909 100644 --- a/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/%name_%version/%sub/test_%service.py.j2 @@ -83,7 +83,7 @@ def test_{{ service.client_name|snake_case }}_client_options(): client = {{ service.client_name }}(client_options=options) transport.assert_called_once_with( credentials=None, - host={{ service.client_name }}.DEFAULT_ENDPOINT, + host=client.DEFAULT_ENDPOINT, ) # Check mTLS is triggered with api endpoint override. @@ -106,10 +106,10 @@ def test_{{ service.client_name|snake_case }}_client_options(): grpc_transport.return_value = None client = {{ service.client_name }}(client_options=options) grpc_transport.assert_called_once_with( - api_mtls_endpoint={{ service.client_name }}.DEFAULT_MTLS_ENDPOINT, + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, client_cert_source=client_cert_source_callback, credentials=None, - host={{ service.client_name }}.DEFAULT_ENDPOINT, + host=client.DEFAULT_ENDPOINT, ) def test_{{ service.client_name|snake_case }}_client_options_from_dict():