Skip to content

Commit ddcea4b

Browse files
authored
Cross patch #470 to ads templates (#510)
1 parent 3713a4b commit ddcea4b

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

packages/gapic-generator/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
369369
method=rpc,
370370
request=request,
371371
response=response,
372+
metadata=metadata,
372373
)
373374
{%- endif %}
374375
{%- if not method.void %}

packages/gapic-generator/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/pagers.py.j2

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
{# This lives within the loop in order to ensure that this template
77
is empty if there are no paged methods.
88
-#}
9-
from typing import Any, Callable, Iterable
9+
from typing import Any, Callable, Iterable, Sequence, Tuple
1010

1111
{% filter sort_lines -%}
1212
{% for method in service.methods.values() | selectattr('paged_result_field') -%}
@@ -35,10 +35,10 @@ class {{ method.name }}Pager:
3535
the most recent response is retained, and thus used for attribute lookup.
3636
"""
3737
def __init__(self,
38-
method: Callable[[{{ method.input.ident }}],
39-
{{ method.output.ident }}],
38+
method: Callable[..., {{ method.output.ident }}],
4039
request: {{ method.input.ident }},
41-
response: {{ method.output.ident }}):
40+
response: {{ method.output.ident }},
41+
metadata: Sequence[Tuple[str, str]] = ())):
4242
"""Instantiate the pager.
4343

4444
Args:
@@ -48,10 +48,13 @@ class {{ method.name }}Pager:
4848
The initial request object.
4949
response (:class:`{{ method.output.ident.sphinx }}`):
5050
The initial response object.
51+
metadata (Sequence[Tuple[str, str]]): Strings which should be
52+
sent along with the request as metadata.
5153
"""
5254
self._method = method
5355
self._request = {{ method.input.ident }}(request)
5456
self._response = response
57+
self._metadata = metadata
5558

5659
def __getattr__(self, name: str) -> Any:
5760
return getattr(self._response, name)
@@ -61,7 +64,7 @@ class {{ method.name }}Pager:
6164
yield self._response
6265
while self._response.next_page_token:
6366
self._request.page_token = self._response.next_page_token
64-
self._response = self._method(self._request)
67+
self._response = self._method(self._request, metadata=self._metadata)
6568
yield self._response
6669

6770
def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}:

packages/gapic-generator/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ from google.api_core import future
2424
from google.api_core import operations_v1
2525
from google.longrunning import operations_pb2
2626
{% endif -%}
27+
{% if service.has_pagers -%}
28+
from google.api_core import gapic_v1
29+
{% endif -%}
2730
{% for method in service.methods.values() -%}
2831
{% for ref_type in method.ref_types
2932
if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation')
@@ -442,9 +445,24 @@ def test_{{ method.name|snake_case }}_pager():
442445
),
443446
RuntimeError,
444447
)
445-
results = [i for i in client.{{ method.name|snake_case }}(
446-
request={},
447-
)]
448+
449+
metadata = ()
450+
{% if method.field_headers -%}
451+
metadata = tuple(metadata) + (
452+
gapic_v1.routing_header.to_grpc_metadata((
453+
{%- for field_header in method.field_headers %}
454+
{%- if not method.client_streaming %}
455+
('{{ field_header }}', ''),
456+
{%- endif %}
457+
{%- endfor %}
458+
)),
459+
)
460+
{% endif -%}
461+
pager = client.{{ method.name|snake_case }}(request={})
462+
463+
assert pager._metadata == metadata
464+
465+
results = [i for i in pager]
448466
assert len(results) == 6
449467
assert all(isinstance(i, {{ method.paged_result_field.message.ident }})
450468
for i in results)

0 commit comments

Comments
 (0)