diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index eefe0cdc7e..812630720b 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -866,13 +866,22 @@ def paged_result_field(self) -> Optional[Field]: """Return the response pagination field if the method is paginated.""" # If the request field lacks any of the expected pagination fields, # then the method is not paginated. - for page_field in ((self.input, int, 'page_size'), - (self.input, str, 'page_token'), + + # The request must have page_token and next_page_token as they keep track of pages + for source, source_type, name in ((self.input, str, 'page_token'), (self.output, str, 'next_page_token')): - field = page_field[0].fields.get(page_field[2], None) - if not field or field.type != page_field[1]: + field = source.fields.get(name, None) + if not field or field.type != source_type: return None + # The request must have max_results or page_size + page_fields = (self.input.fields.get('max_results', None), + self.input.fields.get('page_size', None)) + page_field_size = next( + (field for field in page_fields if field), None) + if not page_field_size or page_field_size.type != int: + return None + # Return the first repeated field. for field in self.output.fields.values(): if field.repeated: diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 index ea08466ba0..ca3cc8d40e 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 @@ -6,7 +6,7 @@ {# This lives within the loop in order to ensure that this template is empty if there are no paged methods. -#} -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional {% filter sort_lines -%} {% for method in service.methods.values() | selectattr('paged_result_field') -%} @@ -68,14 +68,25 @@ class {{ method.name }}Pager: self._response = self._method(self._request, metadata=self._metadata) yield self._response + {% if method.paged_result_field.map %} + def __iter__(self) -> Iterable[Tuple[str, {{ method.paged_result_field.type.fields.get('value').ident }}]]: + for page in self.pages: + yield from page.{{ method.paged_result_field.name}}.items() + + def get(self, key: str) -> Optional[{{ method.paged_result_field.type.fields.get('value').ident }}]: + return self._response.items.get(key) + {% else %} def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}: for page in self.pages: yield from page.{{ method.paged_result_field.name }} + {% endif %} def __repr__(self) -> str: return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) +{# TODO(yon-mg): remove on rest async transport impl #} +{% if 'grpc' in opts.transport %} class {{ method.name }}AsyncPager: """A pager for iterating through ``{{ method.name|snake_case }}`` requests. @@ -138,5 +149,6 @@ class {{ method.name }}AsyncPager: def __repr__(self) -> str: return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) +{% endif %} {% endfor %} {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 7e84b78a99..060e5d0744 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -184,11 +184,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - {% filter sort_lines -%} - {%- for field in method.query_params %} + {%- for field in method.query_params | sort%} '{{ field|camel_case }}': request.{{ field }}, {%- endfor %} - {% endfilter -%} } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index e1c42f89a2..59611cfd33 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1020,7 +1020,7 @@ def test_{{ method.name|snake_case }}_raw_page_lro(): assert response.raw_page is response {% endif %} {#- method.paged_result_field #} -{% endfor -%} {#- method in methods #} +{% endfor -%} {#- method in methods for grpc #} {% for method in service.methods.values() if 'rest' in opts.transport -%} def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type={{ method.input.ident }}): @@ -1162,7 +1162,126 @@ def test_{{ method.name|snake_case }}_rest_flattened_error(): ) -{% endfor -%} +{% if method.paged_result_field %} +def test_{{ method.name|snake_case }}_pager(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # Set the response as a series of pages + {% if method.paged_result_field.map%} + response = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'a':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'b':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'c':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={}, + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'g':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'h':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'i':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + ), + ) + {% else %} + response = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + ], + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[], + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + ], + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + ], + ), + ) + {% endif %} + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple({{ method.output.ident }}.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode('UTF-8') + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + {% if method.field_headers -%} + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + {%- for field_header in method.field_headers %} + {%- if not method.client_streaming %} + ('{{ field_header }}', ''), + {%- endif %} + {%- endfor %} + )), + ) + {% endif -%} + pager = client.{{ method.name|snake_case }}(request={}) + + assert pager._metadata == metadata + + {% if method.paged_result_field.map %} + assert isinstance(pager.get('a'), {{ method.paged_result_field.type.fields.get('value').ident }}) + assert pager.get('h') is None + {% endif %} + + results = list(pager) + assert len(results) == 6 + {% if method.paged_result_field.map %} + assert all( + isinstance(i, tuple) + for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }}) + + assert pager.get('a') is None + assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }}) + {% else %} + assert all(isinstance(i, {{ method.paged_result_field.type.ident }}) + for i in results) + {% endif %} + + pages = list(client.{{ method.name|snake_case }}(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +{% endif %} {# paged methods #} +{% endfor -%} {#- method in methods for rest #} def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index bcaeb68800..2162effbbb 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -66,19 +66,38 @@ def test_method_client_output_empty(): def test_method_client_output_paged(): paged = make_field(name='foos', message=make_message('Foo'), repeated=True) + parent = make_field(name='parent', type=9) # str + page_size = make_field(name='page_size', type=5) # int + page_token = make_field(name='page_token', type=9) # str + input_msg = make_message(name='ListFoosRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str + parent, + page_size, + page_token, )) output_msg = make_message(name='ListFoosResponse', fields=( paged, make_field(name='next_page_token', type=9), # str )) - method = make_method('ListFoos', - input_message=input_msg, - output_message=output_msg, - ) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListFoosPager' + + max_results = make_field(name='max_results', type=5) # int + input_msg = make_message(name='ListFoosRequest', fields=( + parent, + max_results, + page_token, + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) assert method.paged_result_field == paged assert method.client_output.ident.name == 'ListFoosPager' @@ -123,6 +142,19 @@ def test_method_paged_result_field_no_page_field(): ) assert method.paged_result_field is None + method = make_method( + name='Foo', + input_message=make_message( + name='FooRequest', + fields=(make_field(name='page_token', type=9),) # str + ), + output_message=make_message( + name='FooResponse', + fields=(make_field(name='next_page_token', type=9),) # str + ) + ) + assert method.paged_result_field is None + def test_method_paged_result_ref_types(): input_msg = make_message( @@ -139,7 +171,7 @@ def test_method_paged_result_ref_types(): name='ListMolluscsResponse', fields=( make_field(name='molluscs', message=mollusc_msg, repeated=True), - make_field(name='next_page_token', type=9) + make_field(name='next_page_token', type=9) # str ), module='mollusc' ) @@ -207,7 +239,7 @@ def test_flattened_ref_types(): def test_method_paged_result_primitive(): - paged = make_field(name='squids', type=9, repeated=True) + paged = make_field(name='squids', type=9, repeated=True) # str input_msg = make_message( name='ListSquidsRequest', fields=(