From 36e3236b3832993331d8d99c10e72797a8851390 Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Thu, 30 Dec 2021 12:22:28 -0800 Subject: [PATCH] fix: handle message bodies (#1117) Some methods with http annotations have body fields that are message types. Previously, generated unit tests did not handle this well. This is a fix for that: generate a reasonable mock value for the message, represented as a dict. --- gapic/schema/wrappers.py | 17 +++++- tests/fragments/test_non_primitive_body.proto | 53 +++++++++++++++++++ tests/unit/schema/wrappers/test_field.py | 30 ++++++++++- 3 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 tests/fragments/test_non_primitive_body.proto diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index aa474e0e4e..c4c9e6bec0 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -94,7 +94,20 @@ def map(self) -> bool: return bool(self.repeated and self.message and self.message.map) @utils.cached_property - def mock_value_original_type(self) -> Union[bool, str, bytes, int, float, List[Any], None]: + def mock_value_original_type(self) -> Union[bool, str, bytes, int, float, Dict[str, Any], List[Any], None]: + # Return messages as dicts and let the message ctor handle the conversion. + if self.message: + if self.map: + # Not worth the hassle, just return an empty map. + return {} + + msg_dict = { + f.name: f.mock_value_original_type + for f in self.message.fields.values() + } + + return [msg_dict] if self.repeated else msg_dict + answer = self.primitive_mock() or None # If this is a repeated field, then the mock answer should @@ -173,7 +186,7 @@ def primitive_mock(self, suffix: int = 0) -> Union[bool, str, bytes, int, float, answer: Union[bool, str, bytes, int, float, List[Any], None] = None if not isinstance(self.type, PrimitiveType): - raise TypeError(f"'inner_mock_as_original_type' can only be used for" + raise TypeError(f"'primitive_mock' can only be used for " f"PrimitiveType, but type is {self.type}") else: diff --git a/tests/fragments/test_non_primitive_body.proto b/tests/fragments/test_non_primitive_body.proto new file mode 100644 index 0000000000..f322a747a4 --- /dev/null +++ b/tests/fragments/test_non_primitive_body.proto @@ -0,0 +1,53 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.fragment; + +import "google/api/client.proto"; +import "google/api/annotations.proto"; +import "google/api/field_behavior.proto"; + +service SmallCompute { + option (google.api.default_host) = "my.example.com"; + + rpc MyMethod(MethodRequest) returns (MethodResponse) { + option (google.api.http) = { + body: "method_body" + post: "/computation/v1/first_name/{first_name}/last_name/{last_name}" + }; + }; +} + +message SerialNumber { + int32 number = 1; +} + +message MethodRequest { + message MethodBody { + int32 mass_kg = 1; + int32 length_cm = 2; + repeated SerialNumber serial_numbers = 3; + map word_associations = 4; + } + + string first_name = 1 [(google.api.field_behavior) = REQUIRED]; + string last_name = 2 [(google.api.field_behavior) = REQUIRED]; + MethodBody method_body = 3 [(google.api.field_behavior) = REQUIRED]; +} + +message MethodResponse { + string name = 1; +} \ No newline at end of file diff --git a/tests/unit/schema/wrappers/test_field.py b/tests/unit/schema/wrappers/test_field.py index f823104e77..151b2762b8 100644 --- a/tests/unit/schema/wrappers/test_field.py +++ b/tests/unit/schema/wrappers/test_field.py @@ -241,6 +241,7 @@ def test_mock_value_map(): label=3, type='TYPE_MESSAGE', ) + assert field.mock_value == "{'key_value': 'value_value'}" @@ -290,7 +291,7 @@ def test_mock_value_message(): assert field.mock_value == 'bogus.Message(foo=324)' -def test_mock_value_original_type_message_errors(): +def test_mock_value_original_type_message(): subfields = collections.OrderedDict(( ('foo', make_field(name='foo', type='TYPE_INT32')), ('bar', make_field(name='bar', type='TYPE_STRING')) @@ -307,14 +308,39 @@ def test_mock_value_original_type_message_errors(): nested_enums={}, nested_messages={}, ) + field = make_field( type='TYPE_MESSAGE', type_name='bogus.Message', message=message, ) + mock = field.mock_value_original_type + + assert mock == {"foo": 324, "bar": "bar_value"} + + # Messages by definition aren't primitive with pytest.raises(TypeError): - mock = field.mock_value_original_type + field.primitive_mock() + + # Special case for map entries + entry_msg = make_message( + name='MessageEntry', + fields=( + make_field(name='key', type='TYPE_STRING'), + make_field(name='value', type='TYPE_STRING'), + ), + options=descriptor_pb2.MessageOptions(map_entry=True), + ) + entry_field = make_field( + name="messages", + type_name="stuff.MessageEntry", + message=entry_msg, + label=3, + type='TYPE_MESSAGE', + ) + + assert entry_field.mock_value_original_type == {} def test_mock_value_recursive():