Skip to content

Commit a798f00

Browse files
chore: increase async rpc performance (#1755)
1 parent 62855c1 commit a798f00

File tree

38 files changed

+11047
-1583
lines changed

38 files changed

+11047
-1583
lines changed

gapic/templates/%namespace/%name_%version/%sub/services/%service/_client_macros.j2

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@
8787
{% if not method.client_streaming %}
8888
# Create or coerce a protobuf request object.
8989
{% if method.flattened_fields %}
90-
# Quick check: If we got a request object, we should *not* have
91-
# gotten any keyword arguments that map to the request.
90+
# - Quick check: If we got a request object, we should *not* have
91+
# gotten any keyword arguments that map to the request.
9292
has_flattened_params = any([{{ method.flattened_fields.values()|join(", ", attribute="name") }}])
9393
if request is not None and has_flattened_params:
9494
raise ValueError('If the `request` argument is set, then none of '
@@ -97,17 +97,15 @@
9797
{% endif %}
9898
{% if method.input.ident.package != method.ident.package %}{# request lives in a different package, so there is no proto wrapper #}
9999
if isinstance(request, dict):
100-
# The request isn't a proto-plus wrapped type,
101-
# so it must be constructed via keyword expansion.
100+
# - The request isn't a proto-plus wrapped type,
101+
# so it must be constructed via keyword expansion.
102102
request = {{ method.input.ident }}(**request)
103103
elif not request:
104104
# Null request, just make one.
105105
request = {{ method.input.ident }}()
106106
{% else %}
107-
# Minor optimization to avoid making a copy if the user passes
108-
# in a {{ method.input.ident }}.
109-
# There's no risk of modifying the input as we've already verified
110-
# there are no flattened fields.
107+
# - Use the request object if provided (there's no risk of modifying the input as
108+
# there are no flattened fields), or create one.
111109
if not isinstance(request, {{ method.input.ident }}):
112110
request = {{ method.input.ident }}(request)
113111
{% endif %}{# different request package #}

gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -314,23 +314,26 @@ class {{ service.async_client_name }}:
314314
{% if not method.client_streaming %}
315315
# Create or coerce a protobuf request object.
316316
{% if method.flattened_fields %}
317-
# Quick check: If we got a request object, we should *not* have
318-
# gotten any keyword arguments that map to the request.
317+
# - Quick check: If we got a request object, we should *not* have
318+
# gotten any keyword arguments that map to the request.
319319
has_flattened_params = any([{{ method.flattened_fields.values()|join(", ", attribute="name") }}])
320320
if request is not None and has_flattened_params:
321321
raise ValueError("If the `request` argument is set, then none of "
322322
"the individual field arguments should be set.")
323323

324324
{% endif %}
325325
{% if method.input.ident.package != method.ident.package %} {# request lives in a different package, so there is no proto wrapper #}
326-
# The request isn't a proto-plus wrapped type,
327-
# so it must be constructed via keyword expansion.
326+
# - The request isn't a proto-plus wrapped type,
327+
# so it must be constructed via keyword expansion.
328328
if isinstance(request, dict):
329329
request = {{ method.input.ident }}(**request)
330330
elif not request:
331331
request = {{ method.input.ident }}({% if method.flattened_fields %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %})
332332
{% else %}
333-
request = {{ method.input.ident }}(request)
333+
# - Use the request object if provided (there's no risk of modifying the input as
334+
# there are no flattened fields), or create one.
335+
if not isinstance(request, {{ method.input.ident }}):
336+
request = {{ method.input.ident }}(request)
334337
{% endif %} {# different request package #}
335338

336339
{# Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
@@ -357,26 +360,9 @@ class {{ service.async_client_name }}:
357360

358361
# Wrap the RPC method; this adds retry and timeout information,
359362
# and friendly error handling.
360-
rpc = gapic_v1.method_async.wrap_method(
361-
self._client._transport.{{ method.transport_safe_name|snake_case }},
362-
{% if method.retry %}
363-
default_retry=retries.AsyncRetry(
364-
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
365-
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
366-
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
367-
predicate=retries.if_exception_type(
368-
{% for ex in method.retry.retryable_exceptions|sort(attribute="__name__") %}
369-
core_exceptions.{{ ex.__name__ }},
370-
{% endfor %}
371-
),
372-
deadline={{ method.timeout }},
373-
),
374-
{% endif %}
375-
default_timeout={{ method.timeout }},
376-
client_info=DEFAULT_CLIENT_INFO,
377-
)
378-
{% if method.field_headers %}
363+
rpc = self._client._transport._wrapped_methods[self._client._transport.{{ method.transport_safe_name|snake_case }}]
379364

365+
{% if method.field_headers %}
380366
# Certain fields should be provided within the metadata header;
381367
# add these here.
382368
metadata = tuple(metadata) + (

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/base.py.j2

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,15 @@ class {{ service.name }}Transport(abc.ABC):
149149
self.{{ method.transport_safe_name|snake_case }},
150150
{% if method.retry %}
151151
default_retry=retries.Retry(
152-
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %}
153-
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %}
154-
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %}
152+
{% if method.retry.initial_backoff %}
153+
initial={{ method.retry.initial_backoff }},
154+
{% endif %}
155+
{% if method.retry.max_backoff %}
156+
maximum={{ method.retry.max_backoff }},
157+
{% endif %}
158+
{% if method.retry.backoff_multiplier %}
159+
multiplier={{ method.retry.backoff_multiplier }},
160+
{% endif %}
155161
predicate=retries.if_exception_type(
156162
{% for ex in method.retry.retryable_exceptions|sort(attribute='__name__') %}
157163
core_exceptions.{{ ex.__name__ }},

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
77

88
from google.api_core import gapic_v1
99
from google.api_core import grpc_helpers_async
10+
from google.api_core import exceptions as core_exceptions
11+
from google.api_core import retry_async as retries
1012
{% if service.has_lro %}
1113
from google.api_core import operations_v1
1214
{% endif %}
@@ -382,6 +384,37 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
382384
return self._stubs["test_iam_permissions"]
383385
{% endif %}
384386

387+
def _prep_wrapped_messages(self, client_info):
388+
""" Precompute the wrapped methods, overriding the base class method to use async wrappers."""
389+
self._wrapped_methods = {
390+
{% for method in service.methods.values() %}
391+
self.{{ method.transport_safe_name|snake_case }}: gapic_v1.method_async.wrap_method(
392+
self.{{ method.transport_safe_name|snake_case }},
393+
{% if method.retry %}
394+
default_retry=retries.AsyncRetry(
395+
{% if method.retry.initial_backoff %}
396+
initial={{ method.retry.initial_backoff }},
397+
{% endif %}
398+
{% if method.retry.max_backoff %}
399+
maximum={{ method.retry.max_backoff }},
400+
{% endif %}
401+
{% if method.retry.backoff_multiplier %}
402+
multiplier={{ method.retry.backoff_multiplier }},
403+
{% endif %}
404+
predicate=retries.if_exception_type(
405+
{% for ex in method.retry.retryable_exceptions|sort(attribute='__name__') %}
406+
core_exceptions.{{ ex.__name__ }},
407+
{% endfor %}
408+
),
409+
deadline={{ method.timeout }},
410+
),
411+
{% endif %}
412+
default_timeout={{ method.timeout }},
413+
client_info=client_info,
414+
),
415+
{% endfor %} {# service.methods.values() #}
416+
}
417+
385418
def close(self):
386419
return self.grpc_channel.close()
387420

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,54 @@ def test_{{ method_name }}_non_empty_request_with_auto_populated_field():
196196
)
197197
{% endif %}
198198

199+
def test_{{ method_name }}_use_cached_wrapped_rpc():
200+
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
201+
# instead of constructing them on each call
202+
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
203+
client = {{ service.client_name }}(
204+
credentials=ga_credentials.AnonymousCredentials(),
205+
transport="grpc",
206+
)
207+
208+
# Should wrap all calls on client creation
209+
assert wrapper_fn.call_count > 0
210+
wrapper_fn.reset_mock()
211+
212+
# Ensure method has been cached
213+
assert client._transport.{{method.transport_safe_name|snake_case}} in client._transport._wrapped_methods
214+
215+
# Replace cached wrapped function with mock
216+
mock_rpc = mock.Mock()
217+
client._transport._wrapped_methods[client._transport.{{method.transport_safe_name|snake_case}}] = mock_rpc
218+
219+
{% if method.client_streaming %}
220+
request = [{}]
221+
client.{{ method.safe_name|snake_case }}(request)
222+
{% else %}
223+
request = {}
224+
client.{{ method_name }}(request)
225+
{% endif %}
226+
227+
# Establish that the underlying gRPC stub method was called.
228+
assert mock_rpc.call_count == 1
229+
230+
{% if method.lro or method.extended_lro %}
231+
# Operation methods build a cached wrapper on first rpc call
232+
# subsequent calls should use the cached wrapper
233+
wrapper_fn.reset_mock()
234+
{% endif %}
235+
236+
{% if method.client_streaming %}
237+
client.{{ method.safe_name|snake_case }}(request)
238+
{% else %}
239+
client.{{ method_name }}(request)
240+
{% endif %}
241+
242+
243+
# Establish that a new wrapper was not created for this call
244+
assert wrapper_fn.call_count == 0
245+
assert mock_rpc.call_count == 2
246+
199247
{% if not full_extended_lro %}
200248
{% if not method.client_streaming %}
201249
@pytest.mark.asyncio
@@ -253,6 +301,58 @@ async def test_{{ method_name }}_empty_call_async():
253301
assert args[0] == {{ method.input.ident }}()
254302
{% endif %}
255303

304+
@pytest.mark.asyncio
305+
async def test_{{ method_name }}_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"):
306+
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
307+
# instead of constructing them on each call
308+
with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn:
309+
client = {{ service.async_client_name }}(
310+
credentials=ga_credentials.AnonymousCredentials(),
311+
transport=transport,
312+
)
313+
314+
# Should wrap all calls on client creation
315+
assert wrapper_fn.call_count > 0
316+
wrapper_fn.reset_mock()
317+
318+
# Ensure method has been cached
319+
assert client._client._transport.{{method.transport_safe_name|snake_case}} in client._client._transport._wrapped_methods
320+
321+
# Replace cached wrapped function with mock
322+
class AwaitableMock(mock.AsyncMock):
323+
def __await__(self):
324+
self.await_count += 1
325+
return iter([])
326+
mock_object = AwaitableMock()
327+
client._client._transport._wrapped_methods[client._client._transport.{{method.transport_safe_name|snake_case}}] = mock_object
328+
329+
{% if method.client_streaming %}
330+
request = [{}]
331+
await client.{{ method.name|snake_case }}(request)
332+
{% else %}
333+
request = {}
334+
await client.{{ method_name }}(request)
335+
{% endif %}
336+
337+
# Establish that the underlying gRPC stub method was called.
338+
assert mock_object.call_count == 1
339+
340+
{% if method.lro or method.extended_lro %}
341+
# Operation methods build a cached wrapper on first rpc call
342+
# subsequent calls should use the cached wrapper
343+
wrapper_fn.reset_mock()
344+
{% endif %}
345+
346+
{% if method.client_streaming %}
347+
await client.{{ method.name|snake_case }}(request)
348+
{% else %}
349+
await client.{{ method_name }}(request)
350+
{% endif %}
351+
352+
# Establish that a new wrapper was not created for this call
353+
assert wrapper_fn.call_count == 0
354+
assert mock_object.call_count == 2
355+
256356
@pytest.mark.asyncio
257357
async def test_{{ method_name }}_async(transport: str = 'grpc_asyncio', request_type={{ method.input.ident }}):
258358
{% with auto_populated_field_sample_value = "explicit value for autopopulate-able field" %}
@@ -1220,6 +1320,53 @@ def test_{{ method_name }}_rest(request_type):
12201320
{% endfor %}
12211321
{% endif %}
12221322

1323+
def test_{{ method_name }}_rest_use_cached_wrapped_rpc():
1324+
# Clients should use _prep_wrapped_messages to create cached wrapped rpcs,
1325+
# instead of constructing them on each call
1326+
with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn:
1327+
client = {{ service.client_name }}(
1328+
credentials=ga_credentials.AnonymousCredentials(),
1329+
transport="rest",
1330+
)
1331+
1332+
# Should wrap all calls on client creation
1333+
assert wrapper_fn.call_count > 0
1334+
wrapper_fn.reset_mock()
1335+
1336+
# Ensure method has been cached
1337+
assert client._transport.{{method.transport_safe_name|snake_case}} in client._transport._wrapped_methods
1338+
1339+
# Replace cached wrapped function with mock
1340+
mock_rpc = mock.Mock()
1341+
client._transport._wrapped_methods[client._transport.{{method.transport_safe_name|snake_case}}] = mock_rpc
1342+
1343+
{% if method.client_streaming %}
1344+
request = [{}]
1345+
client.{{ method.safe_name|snake_case }}(request)
1346+
{% else %}
1347+
request = {}
1348+
client.{{ method_name }}(request)
1349+
{% endif %}
1350+
1351+
# Establish that the underlying gRPC stub method was called.
1352+
assert mock_rpc.call_count == 1
1353+
1354+
{% if method.lro or method.extended_lro %}
1355+
# Operation methods build a cached wrapper on first rpc call
1356+
# subsequent calls should use the cached wrapper
1357+
wrapper_fn.reset_mock()
1358+
{% endif %}
1359+
1360+
{% if method.client_streaming %}
1361+
client.{{ method.safe_name|snake_case }}(request)
1362+
{% else %}
1363+
client.{{ method_name }}(request)
1364+
{% endif %}
1365+
1366+
# Establish that a new wrapper was not created for this call
1367+
assert wrapper_fn.call_count == 0
1368+
assert mock_rpc.call_count == 2
1369+
12231370

12241371
{% if method.input.required_fields %}
12251372
def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ident }}):

0 commit comments

Comments
 (0)