Skip to content

Commit

Permalink
fix: only set unset fields if they are query params
Browse files Browse the repository at this point in the history
  • Loading branch information
software-dov committed Jan 12, 2022
1 parent d528223 commit 9e8fc1e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not (method.server_streaming or method.client_streaming) %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value or 0 }}{% endif %},{# default is str #}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #}
{% endfor %}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}
Expand All @@ -1003,23 +1003,27 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
{% endfor %}

{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
{% if method.query_params %}
# Check that path parameters and body parameters are not mixing in.
assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param|camel_case }}", {% endfor %}))
{% endif %}
jsonified_request["{{ field_name }}"] = {{ mock_value }}
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)

# verify required fields with non-default values are left alone
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
assert "{{ field_name }}" in jsonified_request
Expand Down Expand Up @@ -1080,7 +1084,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% endif %}

expected_params = [
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
(
"{{ req_field.name | camel_case }}",
{% if req_field.field_pb.type == 9 %}
Expand All @@ -1095,6 +1099,13 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
assert expected_params == actual_params


def test_{{ method_name }}_rest_unset_required_fields():
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)

unset_fields = transport.{{ method_name|snake_case }}._get_unset_required_fields({})
assert set(unset_fields) == set(({% for param in method.query_params %}"{{ param|camel_case }}", {% endfor %}))


{% endif %}{# required_fields #}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not (method.server_streaming or method.client_streaming) %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #}
{% endfor %}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1287,65 +1287,8 @@ def test_{{ method_name }}_raw_page_lro():
{% endfor %} {# method in methods for grpc #}

{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when streaming are supported. #}
{% if not (method.server_streaming or method.client_streaming) %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
])
def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport="rest",
)
# Send a request that will satisfy transcoding
request = {{ method.input.ident }}({{ method.http_options[0].sample_request(method) }})
{% if method.client_streaming %}
requests = [request]
{% endif %}


with mock.patch.object(type(client.transport._session), 'request') as req:
{% if method.void %}
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}(
{% for field in method.output.fields.values() | rejectattr('message')%}
{% if not field.oneof or field.proto3_optional %}
{{ field.name }}={{ field.mock_value }},
{% endif %}{% endfor %}
{# This is a hack to only pick one field #}
{% for oneof_fields in method.output.oneof_fields().values() %}
{% with field = oneof_fields[0] %}
{{ field.name }}={{ field.mock_value }},
{% endwith %}
{% endfor %}
)
{% endif %}
req.return_value = Response()
req.return_value.status_code = 500
req.return_value.request = PreparedRequest()
{% if method.void %}
json_return_value = ''
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
req.return_value._content = json_return_value.encode("UTF-8")
with pytest.raises(core_exceptions.GoogleAPIError):
# We only care that the correct exception is raised when putting
# the request over the wire, so an empty request is fine.
{% if method.client_streaming %}
client.{{ method_name }}(iter([requests]))
{% else %}
client.{{ method_name }}(request)
{% endif %}


{# TODO(kbandes): remove this if condition when lro and streaming are supported. #}
{% if not (method.server_streaming or method.client_streaming) %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand Down Expand Up @@ -1458,7 +1401,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}
Expand All @@ -1467,7 +1410,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
Expand All @@ -1480,6 +1423,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
{% if method.query_params %}
# Check that path parameters and body parameters are not mixing in.
assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param}}", {% endfor %}))
{% endif %}
jsonified_request.update(unset_fields)

# verify required fields with non-default values are left alone
Expand Down Expand Up @@ -1544,7 +1491,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% endif %}

expected_params = [
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
(
"{{ req_field.name | camel_case }}",
{% if req_field.field_pb.type == 9 %}
Expand All @@ -1559,6 +1506,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
assert expected_params == actual_params


def test_{{ method_name }}_rest_unset_required_fields():
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)

unset_fields = transport.{{ method_name|snake_case }}._get_unset_required_fields({})
assert set(unset_fields) == set(({% for param in method.query_params %}"{{ param|camel_case }}", {% endfor %}))

{% endif %}{# required_fields #}


Expand Down
4 changes: 0 additions & 4 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ def __call__(self, frag):
f"--python_gapic_opt=transport=grpc+rest,python-gapic-templates={templates}{maybe_old_naming}",
]

if self.use_ads_templates:
session_args.extend([])

outputs.append(
self.session.run(*session_args, str(frag), external=True, silent=True,)
)
Expand All @@ -114,7 +111,6 @@ def __call__(self, frag):
# Note: install into the tempdir to prevent issues
# with running pip concurrently.
self.session.install(tmp_dir, "-e", ".", "-t", tmp_dir, "-qqq")

# Run the fragment's generated unit tests.
# Don't bother parallelizing them: we already parallelize
# the fragments, and there usually aren't too many tests per fragment.
Expand Down
11 changes: 5 additions & 6 deletions tests/fragments/test_multiple_required_fields.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@ service MultipleRequiredFields {
}
}

message Description {
string description = 1;
}

message MethodRequest {
string kingdom = 1 [(google.api.field_behavior) = REQUIRED];
string phylum = 2 [(google.api.field_behavior) = REQUIRED];
Description description = 3 [(google.api.field_behavior) = REQUIRED];
string name = 3 [(google.api.field_behavior) = REQUIRED];
int32 armor_class = 4 [(google.api.field_behavior) = REQUIRED];
}

message MethodResponse{}
message MethodResponse{
string text = 1;
}

0 comments on commit 9e8fc1e

Please sign in to comment.