Skip to content

Commit

Permalink
fix: fix tests generation logic (#1049)
Browse files Browse the repository at this point in the history
* fix: fix tests generation logic

This includes:
1) Fix test logic for grpc+rest case, when clients with both transports need to be initialized in parametrized tests
2) Fix 100% coverage problem for rest clients, when the http error (>= 400 error code) case logic was not covered.

* fix integration testrs
  • Loading branch information
vam-google authored Oct 29, 2021
1 parent 6b640af commit 8f213ad
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from proto.marshal.rules.dates import DurationRule, TimestampRule

{% if 'rest' in opts.transport %}
from requests import Response
from requests import Request
from requests.sessions import Session
{% endif %}

Expand Down Expand Up @@ -104,7 +105,8 @@ def test_{{ service.client_name|snake_case }}_from_service_account_info(client_c
{% if 'grpc' in opts.transport %}
(transports.{{ service.grpc_transport_name }}, "grpc"),
(transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
{% elif 'rest' in opts.transport %}
{% endif %}
{% if 'rest' in opts.transport %}
(transports.{{ service.rest_transport_name }}, "rest"),
{% endif %}
])
Expand Down Expand Up @@ -160,7 +162,8 @@ def test_{{ service.client_name|snake_case }}_get_transport_class():
{% if 'grpc' in opts.transport %}
({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"),
({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
{% elif 'rest' in opts.transport %}
{% endif %}
{% if 'rest' in opts.transport %}
({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"),
{% endif %}
])
Expand All @@ -186,7 +189,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -203,7 +206,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -220,7 +223,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -247,7 +250,7 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
options = client_options.ClientOptions(quota_project_id="octopus")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -265,7 +268,8 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans
({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio", "true"),
({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc", "false"),
({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio", "false"),
{% elif 'rest' in opts.transport %}
{% endif %}
{% if 'rest' in opts.transport %}
({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest", "true"),
({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest", "false"),
{% endif %}
Expand All @@ -285,7 +289,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)

if use_client_cert_env == "false":
expected_client_cert_source = None
Expand Down Expand Up @@ -319,7 +323,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
expected_client_cert_source = client_cert_source_callback

patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -336,7 +340,7 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
with mock.patch.object(transport_class, '__init__') as patched:
with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False):
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -353,7 +357,8 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp
{% if 'grpc' in opts.transport %}
({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"),
({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
{% elif 'rest' in opts.transport %}
{% endif %}
{% if 'rest' in opts.transport %}
({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"),
{% endif %}
])
Expand All @@ -364,7 +369,7 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class
)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -380,7 +385,8 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class
{% if 'grpc' in opts.transport %}
({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"),
({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"),
{% elif 'rest' in opts.transport %}
{% endif %}
{% if 'rest' in opts.transport %}
({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"),
{% endif %}
])
Expand All @@ -391,7 +397,7 @@ def test_{{ service.client_name|snake_case }}_client_options_credentials_file(cl
)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
Expand Down Expand Up @@ -1182,14 +1188,48 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type
{% endif %}


def test_{{ method.name|snake_case }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# send a request that will satisfy transcoding
request_init = {{ method.http_options[0].sample_request}}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.mock_value }}
{% endif %}
{% endfor %}
request = request_type(request_init)
{% if method.client_streaming %}
requests = [request]
{% endif %}

# Mock the http request call within the method and fake a BadRequest error.
with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
# Wrap the value into a proper Response obj
response_value = Response()
response_value.status_code = 400
response_value.request = Request()
req.return_value = response_value
{% if method.client_streaming %}
client.{{ method.name|snake_case }}(iter(requests))
{% else %}
client.{{ method.name|snake_case }}(request)
{% endif %}


def test_{{ method.name|snake_case }}_rest_from_dict():
test_{{ method.name|snake_case }}_rest(request_type=dict)


{% if method.flattened_fields %}
def test_{{ method.name|snake_case }}_rest_flattened():
def test_{{ method.name|snake_case }}_rest_flattened(transport: str = 'rest'):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Mock the http request call within the method and fake a response.
Expand Down Expand Up @@ -1242,9 +1282,10 @@ def test_{{ method.name|snake_case }}_rest_flattened():
{# TODO(kbandes) - reverse-transcode request args to check all request fields #}


def test_{{ method.name|snake_case }}_rest_flattened_error():
def test_{{ method.name|snake_case }}_rest_flattened_error(transport: str = 'rest'):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Attempting to call a method with both a request object and flattened
Expand Down Expand Up @@ -1460,7 +1501,8 @@ def test_transport_get_channel():
{% if 'grpc' in opts.transport %}
transports.{{ service.grpc_transport_name }},
transports.{{ service.grpc_asyncio_transport_name }},
{% elif 'rest' in opts.transport %}
{% endif %}
{% if 'rest' in opts.transport %}
transports.{{ service.rest_transport_name }},
{% endif %}
])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_asset_service_client_client_options(client_class, transport_class, tran
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -178,7 +178,7 @@ def test_asset_service_client_client_options(client_class, transport_class, tran
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -195,7 +195,7 @@ def test_asset_service_client_client_options(client_class, transport_class, tran
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -222,7 +222,7 @@ def test_asset_service_client_client_options(client_class, transport_class, tran
options = client_options.ClientOptions(quota_project_id="octopus")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand Down Expand Up @@ -253,7 +253,7 @@ def test_asset_service_client_mtls_env_auto(client_class, transport_class, trans
options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)

if use_client_cert_env == "false":
expected_client_cert_source = None
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_asset_service_client_mtls_env_auto(client_class, transport_class, trans
expected_client_cert_source = client_cert_source_callback

patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -304,7 +304,7 @@ def test_asset_service_client_mtls_env_auto(client_class, transport_class, trans
with mock.patch.object(transport_class, '__init__') as patched:
with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False):
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -328,7 +328,7 @@ def test_asset_service_client_client_options_scopes(client_class, transport_clas
)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -351,7 +351,7 @@ def test_asset_service_client_client_options_credentials_file(client_class, tran
)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_iam_credentials_client_client_options(client_class, transport_class, tr
options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -170,7 +170,7 @@ def test_iam_credentials_client_client_options(client_class, transport_class, tr
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -187,7 +187,7 @@ def test_iam_credentials_client_client_options(client_class, transport_class, tr
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -214,7 +214,7 @@ def test_iam_credentials_client_client_options(client_class, transport_class, tr
options = client_options.ClientOptions(quota_project_id="octopus")
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_iam_credentials_client_mtls_env_auto(client_class, transport_class, tra
options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)

if use_client_cert_env == "false":
expected_client_cert_source = None
Expand Down Expand Up @@ -279,7 +279,7 @@ def test_iam_credentials_client_mtls_env_auto(client_class, transport_class, tra
expected_client_cert_source = client_cert_source_callback

patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -296,7 +296,7 @@ def test_iam_credentials_client_mtls_env_auto(client_class, transport_class, tra
with mock.patch.object(transport_class, '__init__') as patched:
with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False):
patched.return_value = None
client = client_class()
client = client_class(transport=transport_name)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -320,7 +320,7 @@ def test_iam_credentials_client_client_options_scopes(client_class, transport_cl
)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file=None,
Expand All @@ -343,7 +343,7 @@ def test_iam_credentials_client_client_options_credentials_file(client_class, tr
)
with mock.patch.object(transport_class, '__init__') as patched:
patched.return_value = None
client = client_class(client_options=options)
client = client_class(transport=transport_name, client_options=options)
patched.assert_called_once_with(
credentials=None,
credentials_file="credentials.json",
Expand Down
Loading

0 comments on commit 8f213ad

Please sign in to comment.