Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: only set unset fields if they are query params #1130

Merged
merged 3 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not (method.server_streaming or method.client_streaming) %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value or 0 }}{% endif %},{# default is str #}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #}
{% endfor %}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}
Expand All @@ -1003,23 +1003,32 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
{% endfor %}

{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
{% if method.query_params %}
# Check that path parameters and body parameters are not mixing in.
assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param|camel_case }}", {% endfor %}))
{% endif %}
jsonified_request["{{ field_name }}"] = {{ mock_value }}
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
{% if method.query_params %}
# Check that path parameters and body parameters are not mixing in.
assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param}}",
{% endfor %}))
{% endif %}
jsonified_request.update(unset_fields)

# verify required fields with non-default values are left alone
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
assert "{{ field_name }}" in jsonified_request
Expand Down Expand Up @@ -1080,7 +1089,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% endif %}

expected_params = [
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
(
"{{ req_field.name | camel_case }}",
{% if req_field.field_pb.type == 9 %}
Expand All @@ -1095,6 +1104,13 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
assert expected_params == actual_params


def test_{{ method_name }}_rest_unset_required_fields():
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)

unset_fields = transport.{{ method.name|snake_case }}._get_unset_required_fields({})
assert set(unset_fields) == (set(({% for param in method.query_params %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{param.name|camel_case}}", {% endfor %})))


{% endif %}{# required_fields #}


Expand Down
2 changes: 1 addition & 1 deletion gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ def path_params(self) -> Sequence[str]:
if self.http_opt is None:
return []

pattern = r'\{(\w+)\}'
pattern = r'\{(\w+)(?:=.+?)?\}'
return re.findall(pattern, self.http_opt['url'])

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not (method.server_streaming or method.client_streaming) %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #}
{% endfor %}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1287,65 +1287,8 @@ def test_{{ method_name }}_raw_page_lro():
{% endfor %} {# method in methods for grpc #}

{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when streaming are supported. #}
{% if not (method.server_streaming or method.client_streaming) %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
])
def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
busunkim96 marked this conversation as resolved.
Show resolved Hide resolved
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport="rest",
)
# Send a request that will satisfy transcoding
request = {{ method.input.ident }}({{ method.http_options[0].sample_request(method) }})
{% if method.client_streaming %}
requests = [request]
{% endif %}


with mock.patch.object(type(client.transport._session), 'request') as req:
{% if method.void %}
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}(
{% for field in method.output.fields.values() | rejectattr('message')%}
{% if not field.oneof or field.proto3_optional %}
{{ field.name }}={{ field.mock_value }},
{% endif %}{% endfor %}
{# This is a hack to only pick one field #}
{% for oneof_fields in method.output.oneof_fields().values() %}
{% with field = oneof_fields[0] %}
{{ field.name }}={{ field.mock_value }},
{% endwith %}
{% endfor %}
)
{% endif %}
req.return_value = Response()
req.return_value.status_code = 500
req.return_value.request = PreparedRequest()
{% if method.void %}
json_return_value = ''
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
req.return_value._content = json_return_value.encode("UTF-8")
with pytest.raises(core_exceptions.GoogleAPIError):
# We only care that the correct exception is raised when putting
# the request over the wire, so an empty request is fine.
{% if method.client_streaming %}
client.{{ method_name }}(iter([requests]))
{% else %}
client.{{ method_name }}(request)
{% endif %}


{# TODO(kbandes): remove this if condition when lro and streaming are supported. #}
{% if not (method.server_streaming or method.client_streaming) %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand Down Expand Up @@ -1458,7 +1401,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}
Expand All @@ -1467,7 +1410,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
Expand All @@ -1480,6 +1423,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
{% if method.query_params %}
# Check that path parameters and body parameters are not mixing in.
assert not set(unset_fields) - set(({% for param in method.query_params %}"{{param}}", {% endfor %}))
{% endif %}
jsonified_request.update(unset_fields)

# verify required fields with non-default values are left alone
Expand Down Expand Up @@ -1544,7 +1491,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% endif %}

expected_params = [
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
(
"{{ req_field.name | camel_case }}",
{% if req_field.field_pb.type == 9 %}
Expand All @@ -1559,6 +1506,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
assert expected_params == actual_params


def test_{{ method_name }}_rest_unset_required_fields():
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)

unset_fields = transport.{{ method.name|snake_case }}._get_unset_required_fields({})
assert set(unset_fields) == (set(({% for param in method.query_params %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{ param.name|camel_case }}", {% endfor %})))

{% endif %}{# required_fields #}


Expand Down
4 changes: 0 additions & 4 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ def __call__(self, frag):
f"--python_gapic_opt=transport=grpc+rest,python-gapic-templates={templates}{maybe_old_naming}",
]

if self.use_ads_templates:
session_args.extend([])

outputs.append(
self.session.run(*session_args, str(frag), external=True, silent=True,)
)
Expand All @@ -114,7 +111,6 @@ def __call__(self, frag):
# Note: install into the tempdir to prevent issues
# with running pip concurrently.
self.session.install(tmp_dir, "-e", ".", "-t", tmp_dir, "-qqq")

# Run the fragment's generated unit tests.
# Don't bother parallelizing them: we already parallelize
# the fragments, and there usually aren't too many tests per fragment.
Expand Down
11 changes: 5 additions & 6 deletions tests/fragments/test_multiple_required_fields.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@ service MultipleRequiredFields {
}
}

message Description {
string description = 1;
}

message MethodRequest {
string kingdom = 1 [(google.api.field_behavior) = REQUIRED];
string phylum = 2 [(google.api.field_behavior) = REQUIRED];
Description description = 3 [(google.api.field_behavior) = REQUIRED];
string name = 3 [(google.api.field_behavior) = REQUIRED];
int32 armor_class = 4 [(google.api.field_behavior) = REQUIRED];
}

message MethodResponse{}
message MethodResponse{
string text = 1;
}
4 changes: 4 additions & 0 deletions tests/unit/schema/wrappers/test_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ def test_method_path_params():
method = make_method('DoSomething', http_rule=http_rule)
assert method.path_params == ['project']

http_rule2 = http_pb2.HttpRule(post='/v1beta1/{name=rooms/*/blurbs/*}')
method2 = make_method("DoSomething", http_rule=http_rule2)
assert method2.path_params == ["name"]


def test_method_path_params_no_http_rule():
method = make_method('DoSomething')
Expand Down