Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix mTLS logic #374

Merged
merged 2 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)
DEFAULT_OPTIONS = ClientOptions.ClientOptions(api_endpoint=DEFAULT_ENDPOINT)

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
Expand Down Expand Up @@ -126,7 +125,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
def __init__(self, *,
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = DEFAULT_OPTIONS,
client_options: ClientOptions = None,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -144,11 +143,10 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
default endpoint provided by the client.
(2) If ``transport`` argument is None, ``client_options`` can be
used to create a mutual TLS transport. If ``api_endpoint`` is
provided and different from the default endpoint, or the
``client_cert_source`` property is provided, mutual TLS
transport will be created if client SSL credentials are found.
Client SSL credentials are obtained from ``client_cert_source``
or application default SSL credentials.
provided or the ``client_cert_source`` property is provided,
mutual TLS transport will be created if client SSL credentials
are found. Client SSL credentials are obtained from
``client_cert_source`` or application default SSL credentials.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
Expand All @@ -157,10 +155,6 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
if isinstance(client_options, dict):
client_options = ClientOptions.from_dict(client_options)

# Set default api endpoint if not set.
if client_options.api_endpoint is None:
client_options.api_endpoint = self.DEFAULT_ENDPOINT

# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
Expand All @@ -170,24 +164,27 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
raise ValueError('When providing a transport instance, '
'provide its credentials directly.')
self._transport = transport
elif transport is not None or (
client_options.api_endpoint == self.DEFAULT_ENDPOINT
elif client_options is None or (
client_options.api_endpoint == None
and client_options.client_cert_source is None
):
# Don't trigger mTLS.
# Don't trigger mTLS if we get an empty ClientOptions.
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials, host=client_options.api_endpoint
credentials=credentials, host=self.DEFAULT_ENDPOINT
)
else:
# Trigger mTLS. If the user overrides endpoint, use it as the mTLS
# endpoint, otherwise use the default mTLS endpoint.
option_endpoint = client_options.api_endpoint
api_mtls_endpoint = self.DEFAULT_MTLS_ENDPOINT if option_endpoint == self.DEFAULT_ENDPOINT else option_endpoint
# We have a non-empty ClientOptions, trigger mTLS. If the user
# doesn't provide endpoint, use the default mTLS endpoint.
if client_options.api_endpoint:
api_endpoint = api_mtls_endpoint = client_options.api_endpoint
else:
api_endpoint = self.DEFAULT_ENDPOINT
api_mtls_endpoint = self.DEFAULT_MTLS_ENDPOINT
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

self._transport = {{ service.name }}GrpcTransport(
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
credentials=credentials,
host=client_options.api_endpoint,
host=api_endpoint,
api_mtls_endpoint=api_mtls_endpoint,
client_cert_source=client_options.client_cert_source,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():


def test_{{ service.client_name|snake_case }}_client_options():
# Check the default options have their expected values.
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
assert {{ service.client_name }}.DEFAULT_OPTIONS.api_endpoint == {{ service.client_name }}.DEFAULT_ENDPOINT

# Check that if channel is provided we won't create a new one.
with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
transport = transports.{{ service.name }}GrpcTransport(
Expand Down