From c43c6d943fa99f202014bf4bba795df25d314a63 Mon Sep 17 00:00:00 2001 From: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com> Date: Tue, 7 Jul 2020 15:29:39 -0700 Subject: [PATCH] fix: pass metadata to pagers (#470) Closes #469 --- gapic/schema/wrappers.py | 5 +++ .../%sub/services/%service/async_client.py.j2 | 1 + .../%sub/services/%service/client.py.j2 | 1 + .../%sub/services/%service/pagers.py.j2 | 26 +++++++++----- .../%name_%version/%sub/test_%service.py.j2 | 24 +++++++++++-- tests/unit/schema/wrappers/test_service.py | 35 +++++++++++++++++++ 6 files changed, 80 insertions(+), 12 deletions(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 5a632fcc02..6f7e041f16 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -832,6 +832,11 @@ def has_lro(self) -> bool: """Return whether the service has a long-running method.""" return any([m.lro for m in self.methods.values()]) + @property + def has_pagers(self) -> bool: + """Return whether the service has paged methods.""" + return any(m.paged_result_field for m in self.methods.values()) + @property def host(self) -> str: """Return the hostname for this service, if specified. diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 index fb501fe2bc..36a34471f8 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 @@ -244,6 +244,7 @@ class {{ service.async_client_name }}: method=rpc, request=request, response=response, + metadata=metadata, ) {%- endif %} {%- if not method.void %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 9b92f6e6fc..c34babd763 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -374,6 +374,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): method=rpc, request=request, response=response, + metadata=metadata, ) {%- endif %} {%- if not method.void %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 index 5c069b68fd..cc7bc56100 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 @@ -6,7 +6,7 @@ {# This lives within the loop in order to ensure that this template is empty if there are no paged methods. -#} -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple {% filter sort_lines -%} {% for method in service.methods.values() | selectattr('paged_result_field') -%} @@ -35,10 +35,11 @@ class {{ method.name }}Pager: the most recent response is retained, and thus used for attribute lookup. """ def __init__(self, - method: Callable[[{{ method.input.ident }}], - {{ method.output.ident }}], + method: Callable[..., {{ method.output.ident }}], request: {{ method.input.ident }}, - response: {{ method.output.ident }}): + response: {{ method.output.ident }}, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -48,10 +49,13 @@ class {{ method.name }}Pager: The initial request object. response (:class:`{{ method.output.ident.sphinx }}`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = {{ method.input.ident }}(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -61,7 +65,7 @@ class {{ method.name }}Pager: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}: @@ -90,10 +94,11 @@ class {{ method.name }}AsyncPager: the most recent response is retained, and thus used for attribute lookup. """ def __init__(self, - method: Callable[[{{ method.input.ident }}], - Awaitable[{{ method.output.ident }}]], + method: Callable[..., Awaitable[{{ method.output.ident }}]], request: {{ method.input.ident }}, - response: {{ method.output.ident }}): + response: {{ method.output.ident }}, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -103,10 +108,13 @@ class {{ method.name }}AsyncPager: The initial request object. response (:class:`{{ method.output.ident.sphinx }}`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = {{ method.input.ident }}(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -116,7 +124,7 @@ class {{ method.name }}AsyncPager: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request) + self._response = await self._method(self._request, metadata=self._metadata) yield self._response def __aiter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'AsyncIterable') }}: diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index ecfe6af704..c21846a4ac 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -27,6 +27,9 @@ from google.api_core import future from google.api_core import operations_v1 from google.longrunning import operations_pb2 {% endif -%} +{% if service.has_pagers -%} +from google.api_core import gapic_v1 +{% endif -%} {% for method in service.methods.values() -%} {% for ref_type in method.ref_types if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation') @@ -695,9 +698,24 @@ def test_{{ method.name|snake_case }}_pager(): ), RuntimeError, ) - results = [i for i in client.{{ method.name|snake_case }}( - request={}, - )] + + metadata = () + {% if method.field_headers -%} + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + {%- for field_header in method.field_headers %} + {%- if not method.client_streaming %} + ('{{ field_header }}', ''), + {%- endif %} + {%- endfor %} + )), + ) + {% endif -%} + pager = client.{{ method.name|snake_case }}(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] assert len(results) == 6 assert all(isinstance(i, {{ method.paged_result_field.message.ident }}) for i in results) diff --git a/tests/unit/schema/wrappers/test_service.py b/tests/unit/schema/wrappers/test_service.py index 86d1aa2e97..8502617b5d 100644 --- a/tests/unit/schema/wrappers/test_service.py +++ b/tests/unit/schema/wrappers/test_service.py @@ -260,3 +260,38 @@ def test_service_any_streaming(): assert service.any_client_streaming == client assert service.any_server_streaming == server + + +def test_has_pagers(): + paged = make_field(name='foos', message=make_message('Foo'), repeated=True) + input_msg = make_message( + name='ListFoosRequest', + fields=( + make_field(name='parent', type=9), # str + make_field(name='page_size', type=5), # int + make_field(name='page_token', type=9), # str + ), + ) + output_msg = make_message( + name='ListFoosResponse', + fields=( + paged, + make_field(name='next_page_token', type=9), # str + ), + ) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + + service = make_service(name="Fooer", methods=(method,),) + assert service.has_pagers + + other_service = make_service( + name="Unfooer", + methods=( + get_method("Unfoo", "foo.bar.UnfooReq", "foo.bar.UnFooResp"), + ), + ) + assert not other_service.has_pagers