From 09692c4e889ccde3b0ca31a5e8476c1679804beb Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Mon, 19 Oct 2020 14:47:22 -0700 Subject: [PATCH] fix: numerous small fixes to allow bigtable-admin (#660) Includes: * tweaked logic around defining recursive message types * more sophisticated logic for generating unit tests using recursive message types * flattened map-y fields are handled properly * fixed a corner case where a method has a third-party request object and flattened fields --- .../%sub/services/%service/client.py.j2 | 14 ++++++--- .../%name_%version/%sub/test_%service.py.j2 | 4 +++ gapic/schema/metadata.py | 15 ++++++---- gapic/schema/wrappers.py | 13 ++++---- .../%sub/services/%service/async_client.py.j2 | 19 ++++++++---- .../%sub/services/%service/client.py.j2 | 16 ++++++---- .../%name_%version/%sub/test_%service.py.j2 | 17 ++++++++--- tests/unit/schema/test_metadata.py | 2 +- tests/unit/schema/wrappers/test_field.py | 30 +++++++++++++++++++ 9 files changed, 96 insertions(+), 34 deletions(-) diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 index 67d5de4680..5f12de323c 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 @@ -333,7 +333,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): request = {{ method.input.ident }}(**request) {% if method.flattened_fields -%}{# Cross-package req and flattened fields #} elif not request: - request = {{ method.input.ident }}() + request = {{ method.input.ident }}({% if method.input.ident.package != method.ident.package %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}, {% endfor %}{% endif %}) {% endif -%}{# Cross-package req and flattened fields #} {%- else %} # Minor optimization to avoid making a copy if the user passes @@ -344,7 +344,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): request = {{ method.input.ident }}(request) {% endif %} {# different request package #} {#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #} - {% if method.flattened_fields -%} + {% if method.flattened_fields and method.input.ident.package == method.ident.package -%} # If we have keyword arguments corresponding to fields on the # request, apply these. {% endif -%} @@ -352,8 +352,14 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): if {{ field.name }} is not None: request.{{ key }} = {{ field.name }} {%- endfor %} - {# They can be _extended_, however -#} - {%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %} + {# Map-y fields can be _updated_, however #} + {%- for key, field in method.flattened_fields.items() if field.map and method.input.ident.package == method.ident.package %} + + if {{ field.name }}: + request.{{ key }}.update({{ field.name }}) + {%- endfor %} + {# And list-y fields can be _extended_ -#} + {%- for key, field in method.flattened_fields.items() if field.repeated and not field.map and method.input.ident.package == method.ident.package %} if {{ field.name }}: request.{{ key }}.extend({{ field.name }}) {%- endfor %} diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 810a6f2e92..14687174cd 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -297,6 +297,10 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ m for message in response: assert isinstance(message, {{ method.output.ident }}) {% else -%} + {% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %} + {# Cheeser assertion to force code coverage for bad paginated methods #} + assert response.raw_page is response + {% endif %} assert isinstance(response, {{ method.client_output.ident }}) {% for field in method.output.fields.values() | rejectattr('message') -%}{% if not field.oneof or field.proto3_optional %} {% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#} diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index b801bb7603..4e78119f91 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -242,12 +242,15 @@ def rel(self, address: 'Address') -> str: # It is possible that a field references a message that has # not yet been declared. If so, send its name enclosed in quotes # (a string) instead. - if self.module_path > address.module_path or self == address: - return f"'{'.'.join(self.parent + (self.name,))}'" - - # This is a message in the same module, already declared. - # Send its name. - return '.'.join(self.parent + (self.name,)) + # + # Note: this is a conservative construction; it generates a stringy + # identifier all the time when it may be possible to use a regular + # module lookup. + # On the other hand, there's no reason _not_ to use a stringy + # identifier. It is guaranteed to work all the time because + # it bumps name resolution until a time when all types in a module + # are guaranteed to be fully defined. + return f"'{'.'.join(self.parent + (self.name,))}'" # Return the usual `module.Name`. return str(self) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 1b0db83e5f..62c36270bf 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -55,9 +55,6 @@ class Field: ) oneof: Optional[str] = None - # Arbitrary cap set via heuristic rule of thumb. - MAX_MOCK_DEPTH: int = 20 - def __getattr__(self, name): return getattr(self.field_pb, name) @@ -93,17 +90,16 @@ def map(self) -> bool: @utils.cached_property def mock_value(self) -> str: - depth = 0 + visited_fields: Set["Field"] = set() stack = [self] answer = "{}" while stack: expr = stack.pop() - answer = answer.format(expr.inner_mock(stack, depth)) - depth += 1 + answer = answer.format(expr.inner_mock(stack, visited_fields)) return answer - def inner_mock(self, stack, depth): + def inner_mock(self, stack, visited_fields): """Return a repr of a valid, usually truthy mock value.""" # For primitives, send a truthy value computed from the # field name. @@ -137,10 +133,11 @@ def inner_mock(self, stack, depth): and isinstance(self.type, MessageType) and len(self.type.fields) # Nested message types need to terminate eventually - and depth < self.MAX_MOCK_DEPTH + and self not in visited_fields ): sub = next(iter(self.type.fields.values())) stack.append(sub) + visited_fields.add(self) # Don't do the recursive rendering here, just set up # where the nested value should go with the double {}. answer = f'{self.type.ident}({sub.name}={{}})' diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 index 0f2e88700f..9d5150d869 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 @@ -169,7 +169,8 @@ class {{ service.async_client_name }}: {% if method.flattened_fields -%} # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]): + has_flattened_params = any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]) + if request is not None and has_flattened_params: raise ValueError('If the `request` argument is set, then none of ' 'the individual field arguments should be set.') @@ -181,23 +182,29 @@ class {{ service.async_client_name }}: request = {{ method.input.ident }}(**request) {% if method.flattened_fields -%}{# Cross-package req and flattened fields #} elif not request: - request = {{ method.input.ident }}() + request = {{ method.input.ident }}({% if method.input.ident.package != method.ident.package %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}, {% endfor %}{% endif %}) {% endif -%}{# Cross-package req and flattened fields #} {%- else %} request = {{ method.input.ident }}(request) {% endif %} {# different request package #} {#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #} - {% if method.flattened_fields -%} + {% if method.flattened_fields and method.input.ident.package == method.ident.package -%} # If we have keyword arguments corresponding to fields on the # request, apply these. {% endif -%} - {%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %} + {%- for key, field in method.flattened_fields.items() if not field.repeated and method.input.ident.package == method.ident.package %} if {{ field.name }} is not None: request.{{ key }} = {{ field.name }} {%- endfor %} - {# They can be _extended_, however -#} - {%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %} + {# Map-y fields can be _updated_, however #} + {%- for key, field in method.flattened_fields.items() if field.map and method.input.ident.package == method.ident.package %} + + if {{ field.name }}: + request.{{ key }}.update({{ field.name }}) + {%- endfor %} + {# And list-y fields can be _extended_ -#} + {%- for key, field in method.flattened_fields.items() if field.repeated and not field.map and method.input.ident.package == method.ident.package %} if {{ field.name }}: request.{{ key }}.extend({{ field.name }}) {%- endfor %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index c8040284d5..c3093aa1cf 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -345,7 +345,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): request = {{ method.input.ident }}(**request) {% if method.flattened_fields -%}{# Cross-package req and flattened fields #} elif not request: - request = {{ method.input.ident }}() + request = {{ method.input.ident }}({% if method.input.ident.package != method.ident.package %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}, {% endfor %}{% endif %}) {% endif -%}{# Cross-package req and flattened fields #} {%- else %} # Minor optimization to avoid making a copy if the user passes @@ -357,16 +357,22 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): {% endif %} {# different request package #} {#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #} - {% if method.flattened_fields -%} + {% if method.flattened_fields and method.input.ident.package == method.ident.package -%} # If we have keyword arguments corresponding to fields on the # request, apply these. {% endif -%} - {%- for key, field in method.flattened_fields.items() if not(field.repeated or method.input.ident.package != method.ident.package) %} + {%- for key, field in method.flattened_fields.items() if not field.repeated and method.input.ident.package == method.ident.package %} if {{ field.name }} is not None: request.{{ key }} = {{ field.name }} {%- endfor %} - {# They can be _extended_, however -#} - {%- for key, field in method.flattened_fields.items() if field.repeated %} + {# Map-y fields can be _updated_, however #} + {%- for key, field in method.flattened_fields.items() if field.map and method.input.ident.package == method.ident.package %} + + if {{ field.name }}: + request.{{ key }}.update({{ field.name }}) + {%- endfor %} + {# And list-y fields can be _extended_ -#} + {%- for key, field in method.flattened_fields.items() if field.repeated and not field.map and method.input.ident.package == method.ident.package %} if {{ field.name }}: request.{{ key }}.extend({{ field.name }}) {%- endfor %} diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 359b548a0b..dd4fd637ec 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -398,6 +398,10 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ m for message in response: assert isinstance(message, {{ method.output.ident }}) {% else -%} + {% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %} + {# Cheeser assertion to force code coverage for bad paginated methods #} + assert response.raw_page is response + {% endif %} assert isinstance(response, {{ method.client_output.ident }}) {% for field in method.output.fields.values() | rejectattr('message') -%}{% if not field.oneof or field.proto3_optional %} {% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#} @@ -417,7 +421,7 @@ def test_{{ method.name|snake_case }}_from_dict(): @pytest.mark.asyncio -async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio'): +async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio', request_type={{ method.input.ident }}): client = {{ service.async_client_name }}( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -425,7 +429,7 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = {{ method.input.ident }}() + request = request_type() {% if method.client_streaming %} requests = [request] {% endif %} @@ -474,7 +478,7 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio {% if method.client_streaming %} assert next(args[0]) == request {% else %} - assert args[0] == request + assert args[0] == {{ method.input.ident }}() {% endif %} # Establish that the response is the type that we expect. @@ -500,6 +504,11 @@ async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio {% endif %} +@pytest.mark.asyncio +async def test_{{ method.name|snake_case }}_async_from_dict(): + await test_{{ method.name|snake_case }}_async(request_type=dict) + + {% if method.field_headers and not method.client_streaming %} def test_{{ method.name|snake_case }}_field_headers(): client = {{ service.client_name }}( @@ -592,7 +601,7 @@ async def test_{{ method.name|snake_case }}_field_headers_async(): {% endif %} {% if method.ident.package != method.input.ident.package %} -def test_{{ method.name|snake_case }}_from_dict(): +def test_{{ method.name|snake_case }}_from_dict_foreign(): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), ) diff --git a/tests/unit/schema/test_metadata.py b/tests/unit/schema/test_metadata.py index 62cd957cbd..4be166bc1b 100644 --- a/tests/unit/schema/test_metadata.py +++ b/tests/unit/schema/test_metadata.py @@ -70,7 +70,7 @@ def test_address_rel(): addr = metadata.Address(package=('foo', 'bar'), module='baz', name='Bacon') assert addr.rel( metadata.Address(package=('foo', 'bar'), module='baz'), - ) == 'Bacon' + ) == "'Bacon'" def test_address_rel_other(): diff --git a/tests/unit/schema/wrappers/test_field.py b/tests/unit/schema/wrappers/test_field.py index bba280b8f5..99f2edc9f7 100644 --- a/tests/unit/schema/wrappers/test_field.py +++ b/tests/unit/schema/wrappers/test_field.py @@ -19,6 +19,7 @@ from google.api import field_behavior_pb2 from google.protobuf import descriptor_pb2 +from gapic.schema import api from gapic.schema import metadata from gapic.schema import wrappers @@ -250,6 +251,35 @@ def test_mock_value_message(): assert field.mock_value == 'bogus.Message(foo=324)' +def test_mock_value_recursive(): + # The elaborate setup is an unfortunate requirement. + file_pb = descriptor_pb2.FileDescriptorProto( + name="turtle.proto", + package="animalia.chordata.v2", + message_type=( + descriptor_pb2.DescriptorProto( + # It's turtles all the way down ;) + name="Turtle", + field=( + descriptor_pb2.FieldDescriptorProto( + name="turtle", + type="TYPE_MESSAGE", + type_name=".animalia.chordata.v2.Turtle", + number=1, + ), + ), + ), + ), + ) + my_api = api.API.build([file_pb], package="animalia.chordata.v2") + turtle_field = my_api.messages["animalia.chordata.v2.Turtle"].fields["turtle"] + + # If not handled properly, this will run forever and eventually OOM. + actual = turtle_field.mock_value + expected = "ac_turtle.Turtle(turtle=ac_turtle.Turtle(turtle=turtle.Turtle(turtle=None)))" + assert actual == expected + + def test_field_name_kword_disambiguation(): from_field = make_field( name="from",