Skip to content

Commit

Permalink
feat: allow user-provided client info (#573)
Browse files Browse the repository at this point in the history
Fix for googleapis/python-kms#37, #566, and similar.
  • Loading branch information
software-dov authored Aug 17, 2020
1 parent 7c2bab7 commit b2e5274
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ from google.oauth2 import service_account # type: ignore
{% endfor -%}
{% endfor -%}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .transports.grpc import {{ service.name }}GrpcTransport


Expand Down Expand Up @@ -135,6 +135,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -160,6 +161,11 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
(2) The ``client_cert_source`` property is used to provide client
SSL credentials for mutual TLS transport. If not provided, the
default SSL credentials will be used if present.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
Expand Down Expand Up @@ -209,6 +215,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
host=client_options.api_endpoint,
api_mtls_endpoint=client_options.api_endpoint,
client_cert_source=client_options.client_cert_source,
client_info=client_info,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ from google.auth import credentials # type: ignore
{% endfilter %}

try:
_client_info = gapic_v1.client_info.ClientInfo(
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


class {{ service.name }}Transport(metaclass=abc.ABCMeta):
Expand All @@ -43,6 +43,7 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: credentials.Credentials = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the transport.

Expand All @@ -54,6 +55,11 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
credentials identify the application to the service; if none
are specified, the client will attempt to ascertain the
credentials from the environment.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -69,9 +75,9 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
self._credentials = credentials

# Lifted into its own function so it can be stubbed out during tests.
self._prep_wrapped_messages()
self._prep_wrapped_messages(client_info)

def _prep_wrapped_messages(self):
def _prep_wrapped_messages(self, client_info):
# Precomputed wrapped methods
self._wrapped_methods = {
{% for method in service.methods.values() -%}
Expand All @@ -92,7 +98,7 @@ class {{ service.name }}Transport(metaclass=abc.ABCMeta):
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
client_info=client_info,
),
{% endfor %} {# precomputed wrappers loop #}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from google.api_core import grpc_helpers # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
{%- endif %}
from google.api_core import gapic_v1 # type: ignore
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
Expand All @@ -20,7 +21,7 @@ import grpc # type: ignore
{{ method.output.ident.python_import }}
{% endfor -%}
{% endfilter %}
from .base import {{ service.name }}Transport
from .base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO


class {{ service.name }}GrpcTransport({{ service.name }}Transport):
Expand All @@ -40,7 +41,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials: credentials.Credentials = None,
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the transport.

Args:
Expand All @@ -62,6 +65,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
Expand Down Expand Up @@ -101,7 +109,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
self._stubs = {} # type: Dict[str, Callable]

# Run the base constructor.
super().__init__(host=host, credentials=credentials)
super().__init__(
host=host,
credentials=credentials,
client_info=client_info,
)


@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ from google.api_core import future
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% endif -%}
{% if service.has_pagers -%}
from google.api_core import gapic_v1
{% endif -%}
{% for method in service.methods.values() -%}
{% for ref_type in method.ref_types
if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation')
Expand Down Expand Up @@ -109,6 +107,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -122,6 +121,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -135,6 +135,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -149,6 +150,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=client_cert_source_callback,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -163,6 +165,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_MTLS_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -177,6 +180,7 @@ def test_{{ service.client_name|snake_case }}_client_options():
client_cert_source=None,
credentials=None,
host=client.DEFAULT_ENDPOINT,
client_info=transports.base.DEFAULT_CLIENT_INFO,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
Expand All @@ -197,6 +201,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
client_cert_source=None,
credentials=None,
host="squid.clam.whelk",
client_info=transports.base.DEFAULT_CLIENT_INFO,
)


Expand Down Expand Up @@ -769,4 +774,23 @@ def test_parse_{{ message.resource_type|snake_case }}_path():
{% endwith -%}
{% endfor -%}

def test_client_withDEFAULT_CLIENT_INFO():
client_info = gapic_v1.client_info.ClientInfo()

with mock.patch.object(transports.{{ service.name }}Transport, '_prep_wrapped_messages') as prep:
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
client_info=client_info,
)
prep.assert_called_once_with(client_info)

with mock.patch.object(transports.{{ service.name }}Transport, '_prep_wrapped_messages') as prep:
transport_class = {{ service.client_name }}.get_transport_class()
transport = transport_class(
credentials=credentials.AnonymousCredentials(),
client_info=client_info,
)
prep.assert_called_once_with(client_info)


{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ from google.iam.v1 import iam_policy_pb2 as iam_policy # type: ignore
from google.iam.v1 import policy_pb2 as policy # type: ignore
{% endif %}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}
from .client import {{ service.client_name }}

Expand All @@ -52,6 +52,7 @@ class {{ service.async_client_name }}:
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = 'grpc_asyncio',
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand Down Expand Up @@ -87,6 +88,8 @@ class {{ service.async_client_name }}:
credentials=credentials,
transport=transport,
client_options=client_options,
client_info=client_info,

)

{% for method in service.methods.values() -%}
Expand Down Expand Up @@ -202,7 +205,7 @@ class {{ service.async_client_name }}:
),
{%- endif %}
default_timeout={{ method.timeout }},
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)
{%- if method.field_headers %}

Expand Down Expand Up @@ -352,7 +355,7 @@ class {{ service.async_client_name }}:
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.set_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -459,7 +462,7 @@ class {{ service.async_client_name }}:
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.get_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -510,7 +513,7 @@ class {{ service.async_client_name }}:
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.test_iam_permissions,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand All @@ -527,13 +530,13 @@ class {{ service.async_client_name }}:
{% endif %}

try:
_client_info = gapic_v1.client_info.ClientInfo(
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


__all__ = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ from google.iam.v1 import iam_policy_pb2 as iam_policy # type: ignore
from google.iam.v1 import policy_pb2 as policy # type: ignore
{% endif %}
{% endfilter %}
from .transports.base import {{ service.name }}Transport
from .transports.base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .transports.grpc import {{ service.grpc_transport_name }}
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }}

Expand Down Expand Up @@ -141,6 +141,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -166,7 +167,12 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
(2) The ``client_cert_source`` property is used to provide client
SSL credentials for mutual TLS transport. If not provided, the
default SSL credentials will be used if present.

client_info (google.api_core.gapic_v1.client_info.ClientInfo):
The client info used to send a user-agent string along with
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
creation failed for any reason.
Expand Down Expand Up @@ -219,6 +225,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
api_mtls_endpoint=client_options.api_endpoint,
client_cert_source=client_options.client_cert_source,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)


Expand Down Expand Up @@ -471,7 +478,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
rpc = gapic_v1.method.wrap_method(
self._transport.set_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -578,7 +585,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
rpc = gapic_v1.method.wrap_method(
self._transport.get_iam_policy,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand Down Expand Up @@ -629,7 +636,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
rpc = gapic_v1.method.wrap_method(
self._transport.test_iam_permissions,
default_timeout=None,
client_info=_client_info,
client_info=DEFAULT_CLIENT_INFO,
)

# Certain fields should be provided within the metadata header;
Expand All @@ -647,13 +654,13 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):


try:
_client_info = gapic_v1.client_info.ClientInfo(
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
gapic_version=pkg_resources.get_distribution(
'{{ api.naming.warehouse_package_name }}',
).version,
)
except pkg_resources.DistributionNotFound:
_client_info = gapic_v1.client_info.ClientInfo()
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


__all__ = (
Expand Down
Loading

0 comments on commit b2e5274

Please sign in to comment.