Skip to content

Commit

Permalink
feat: support quota project override via client options (#496)
Browse files Browse the repository at this point in the history
* feat: support quota project override via client options

* chore(deps): bump api-core version

* Update setup.py.j2
  • Loading branch information
busunkim96 authored Jul 22, 2020
1 parent f86a47b commit bbc6b36
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
scopes=client_options.scopes,
api_mtls_endpoint=client_options.api_endpoint,
client_cert_source=client_options.client_cert_source,
quota_project_id=client_options.quota_project_id,
)

{% for method in service.methods.values() -%}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class {{ service.name }}Transport(abc.ABC):
credentials: credentials.Credentials = None,
credentials_file: typing.Optional[str] = None,
scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
quota_project_id: typing.Optional[str] = None,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -49,6 +50,8 @@ class {{ service.name }}Transport(abc.ABC):
be loaded with :func:`google.auth.load_credentials_from_file`.
This argument is mutually exclusive with credentials.
scope (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ':' not in host:
Expand All @@ -61,9 +64,14 @@ class {{ service.name }}Transport(abc.ABC):
raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")

if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(credentials_file, scopes=scopes)
credentials, _ = auth.load_credentials_from_file(
credentials_file,
scopes=scopes,
quota_project_id=quota_project_id
)

elif credentials is None:
credentials, _ = auth.default(scopes=scopes)
credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)

# Save the credentials.
self._credentials = credentials
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes: Sequence[str] = None,
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None) -> None:
"""Instantiate the transport.

Args:
Expand All @@ -71,6 +72,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
Expand All @@ -89,7 +92,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
Expand All @@ -108,14 +111,16 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials_file=credentials_file,
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

# Run the base constructor.
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes or self.AUTH_SCOPES
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

self._stubs = {} # type: Dict[str, Callable]
Expand All @@ -126,6 +131,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials: credentials.Credentials = None,
credentials_file: str = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
**kwargs) -> grpc.Channel:
"""Create and return a gRPC channel object.
Args:
Expand All @@ -141,6 +147,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
scopes (Optional[Sequence[str]]): A optional list of scopes needed for this
service. These are only used when credentials are not specified and
are passed to :func:`google.auth.default`.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
kwargs (Optional[dict]): Keyword arguments, which are passed to the
channel creation.
Returns:
Expand All @@ -156,6 +164,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
credentials: credentials.Credentials = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
**kwargs) -> aio.Channel:
"""Create and return a gRPC AsyncIO channel object.
Args:
Expand All @@ -60,6 +61,8 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
scopes (Optional[Sequence[str]]): A optional list of scopes needed for this
service. These are only used when credentials are not specified and
are passed to :func:`google.auth.default`.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
kwargs (Optional[dict]): Keyword arguments, which are passed to the
channel creation.
Returns:
Expand All @@ -71,6 +74,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**kwargs
)

Expand All @@ -81,7 +85,9 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
scopes: Optional[Sequence[str]] = None,
channel: aio.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
) -> None:
"""Instantiate the transport.

Args:
Expand Down Expand Up @@ -109,6 +115,8 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
Expand Down Expand Up @@ -143,14 +151,16 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
credentials_file=credentials_file,
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

# Run the base constructor.
super().__init__(
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes or self.AUTH_SCOPES
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
)

self._stubs = {}
Expand Down
2 changes: 1 addition & 1 deletion gapic/templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ setuptools.setup(
platforms='Posix; MacOS X; Windows',
include_package_data=True,
install_requires=(
'google-api-core[grpc] >= 1.21.0, < 2.0.0dev',
'google-api-core[grpc] >= 1.22.0, < 2.0.0dev',
'libcst >= 0.2.5',
'proto-plus >= 1.1.0',
{%- if api.requires_package(('google', 'iam', 'v1')) %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -127,6 +128,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
Expand All @@ -142,6 +144,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -158,6 +161,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=client_cert_source_callback,
quota_project_id=None,

)

Expand All @@ -175,6 +179,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
Expand All @@ -191,15 +196,29 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
# unsupported value.
os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported"
with pytest.raises(MutualTLSChannelError):
client = client_class()
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}):
with pytest.raises(MutualTLSChannelError):
client = client_class()

del os.environ["GOOGLE_API_USE_MTLS"]
# Check the case quota_project_id is provided
options = client_options.ClientOptions(quota_project_id="octopus")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
host=client.DEFAULT_ENDPOINT,
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id="octopus",
)


@pytest.mark.parametrize("client_class,transport_class,transport_name", [
Expand All @@ -221,6 +240,7 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class
scopes=["1", "2"],
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)


Expand All @@ -243,6 +263,7 @@ def test_{{ service.client_name|snake_case }}_client_options_credentials_file(cl
scopes=None,
api_mtls_endpoint=client.DEFAULT_ENDPOINT,
client_cert_source=None,
quota_project_id=None,
)


Expand All @@ -259,6 +280,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
scopes=None,
api_mtls_endpoint="squid.clam.whelk",
client_cert_source=None,
quota_project_id=None,
)


Expand Down Expand Up @@ -1001,12 +1023,15 @@ def test_{{ service.name|snake_case }}_base_transport_with_credentials_file():
load_creds.return_value = (credentials.AnonymousCredentials(), None)
transport = transports.{{ service.name }}Transport(
credentials_file="credentials.json",
quota_project_id="octopus",
)
load_creds.assert_called_once_with("credentials.json", scopes=(
{%- for scope in service.oauth_scopes %}
'{{ scope }}',
{%- endfor %}
))
),
quota_project_id="octopus",
)


def test_{{ service.name|snake_case }}_auth_adc():
Expand All @@ -1017,22 +1042,23 @@ def test_{{ service.name|snake_case }}_auth_adc():
adc.assert_called_once_with(scopes=(
{%- for scope in service.oauth_scopes %}
'{{ scope }}',
{%- endfor %}
))
{%- endfor %}),
quota_project_id=None,
)


def test_{{ service.name|snake_case }}_transport_auth_adc():
# If credentials and host are not provided, the transport class should use
# ADC credentials.
with mock.patch.object(auth, 'default') as adc:
adc.return_value = (credentials.AnonymousCredentials(), None)
transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk")
transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk", quota_project_id="octopus")
adc.assert_called_once_with(scopes=(
{%- for scope in service.oauth_scopes %}
'{{ scope }}',
{%- endfor %}
))

{%- endfor %}),
quota_project_id="octopus",
)

def test_{{ service.name|snake_case }}_host_no_port():
{% with host = (service.host|default('localhost', true)).split(':')[0] -%}
Expand Down Expand Up @@ -1122,6 +1148,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_client_c
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down Expand Up @@ -1160,6 +1187,7 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel_mtls_with_
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down Expand Up @@ -1200,6 +1228,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel_mtls_with_adc(
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down Expand Up @@ -1240,6 +1269,7 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel_mtls_with_
{%- endfor %}
),
ssl_credentials=mock_ssl_cred,
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
click==7.1.2
google-api-core==1.21.0
google-api-core==1.22.0
googleapis-common-protos==1.52.0
jinja2==2.11.2
MarkupSafe==1.1.1
Expand Down

0 comments on commit bbc6b36

Please sign in to comment.