Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: resolve issue handling protobuf responses in rest streaming #604

Merged
merged 5 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions google/api_core/rest_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,28 @@

from collections import deque
import string
from typing import Deque
from typing import Deque, Union

import proto
import requests
import google.protobuf.message
from google.protobuf.json_format import Parse


class ResponseIterator:
"""Iterator over REST API responses.

Args:
response (requests.Response): An API response object.
response_message_cls (Callable[proto.Message]): A proto
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
class expected to be returned from an API.
"""

def __init__(self, response: requests.Response, response_message_cls):
def __init__(
self,
response: requests.Response,
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
):
self._response = response
self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
Expand Down Expand Up @@ -107,7 +114,10 @@ def __next__(self):

def _grab(self):
# Add extra quotes to make json.loads happy.
return self._response_message_cls.from_json(self._ready_objs.popleft())
if issubclass(self._response_message_cls, proto.Message):
return self._response_message_cls.from_json(self._ready_objs.popleft())
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: for safety and future-proofing, maybe this should be an elif and we check for google.protobuf.Message
(I realize we don't expect to ever have the second check fail if we get here....but "we don't expect" == "famous last words")

(not a blocker)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 9f5b22f

return Parse(self._ready_objs.popleft(), self._response_message_cls())

def __iter__(self):
return self
208 changes: 156 additions & 52 deletions tests/unit/test_rest_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
import requests

from google.api_core import rest_streaming
from google.api import http_pb2
from google.api import httpbody_pb2
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from google.protobuf.json_format import MessageToJson


__protobuf__ = proto.module(package=__name__)
Expand Down Expand Up @@ -98,7 +101,10 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes:
# json.dumps returns a string surrounded with quotes that need to be stripped
# in order to be an actual JSON.
json_responses = [
self._response_message_cls.to_json(r).strip('"') for r in responses
self._response_message_cls.to_json(r).strip('"')
if issubclass(self._response_message_cls, proto.Message)
else MessageToJson(r).strip('"')
for r in responses
]
logging.info(f"Sending JSON stream: {json_responses}")
ret_val = "[{}]".format(",".join(json_responses))
Expand All @@ -114,103 +120,201 @@ def iter_content(self, *args, **kwargs):
)


@pytest.mark.parametrize("random_split", [False])
def test_next_simple(random_split):
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus,response_type",
[(False, True, EchoResponse), (False, False, httpbody_pb2.HttpBody)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that you're not testing different combinations of resp_message_is_proto_plus and response_type, and that you're already switching on the first to construct the responses (which you have to because they take different parameters): I suggest you don't parametrize on response_type and instead, in the function, do

if resp_message_is_proto_plus:
  response_type = EchoResponse
  responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
else:
  response_type = httpbody_pb2.HttpBody
  responses = [
            httpbody_pb2.HttpBody(content_type="hello world"),
            httpbody_pb2.HttpBody(content_type="yes"),
        ]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5b7a4ce

)
def test_next_simple(random_split, resp_message_is_proto_plus, response_type):
if resp_message_is_proto_plus:
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
else:
responses = [
httpbody_pb2.HttpBody(content_type="hello world"),
httpbody_pb2.HttpBody(content_type="yes"),
]

resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=EchoResponse
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_nested(random_split):
responses = [
Song(title="some song", composer=Composer(given_name="some name")),
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus,response_type",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above: since you have the if below already, don't parametrize on response_type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5b7a4ce

[
(True, True, Song),
(False, True, Song),
(True, False, http_pb2.HttpRule),
(False, False, http_pb2.HttpRule),
],
)
def test_next_nested(random_split, resp_message_is_proto_plus, response_type):
if resp_message_is_proto_plus:
responses = [
Song(title="some song", composer=Composer(given_name="some name")),
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
]
else:
# Although `http_pb2.HttpRule`` is used in the response, any response message
# can be used which meets this criteria for the test of having a nested field.
responses = [
http_pb2.HttpRule(
selector="some selector",
custom=http_pb2.CustomHttpPattern(kind="some kind"),
),
http_pb2.HttpRule(
selector="another selector",
custom=http_pb2.CustomHttpPattern(path="some path"),
),
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_stress(random_split):
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus,response_type",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto on parametrization

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5b7a4ce

[
(True, True, Song),
(False, True, Song),
(True, False, http_pb2.HttpRule),
(False, False, http_pb2.HttpRule),
],
)
def test_next_stress(random_split, resp_message_is_proto_plus, response_type):
n = 50
responses = [
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
for i in range(n)
]
if resp_message_is_proto_plus:
responses = [
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
for i in range(n)
]
else:
responses = [
http_pb2.HttpRule(
selector="selector_%d" % i,
custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
)
for i in range(n)
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_escaped_characters_in_string(random_split):
composer_with_relateds = Composer()
relateds = ["Artist A", "Artist B"]
composer_with_relateds.relateds = relateds

responses = [
Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
Song(
title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
),
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
]
@pytest.mark.parametrize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto on parametrizing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5b7a4ce

"random_split,resp_message_is_proto_plus,response_type",
[
(True, True, Song),
(False, True, Song),
(True, False, http_pb2.Http),
(False, False, http_pb2.Http),
],
)
def test_next_escaped_characters_in_string(
random_split, resp_message_is_proto_plus, response_type
):
if resp_message_is_proto_plus:
composer_with_relateds = Composer()
relateds = ["Artist A", "Artist B"]
composer_with_relateds.relateds = relateds

responses = [
Song(
title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
),
Song(
title='{"this is weird": "totally"}',
composer=Composer(given_name="\\{}\\"),
),
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
]
else:
responses = [
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='ti"tle\nfoo\tbar{}',
custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
)
]
),
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='{"this is weird": "totally"}',
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
)
]
),
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='\\{"key": ["value",]}\\',
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
)
]
),
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


def test_next_not_array():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_next_not_array(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
) as mock_method:

resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()


def test_cancel():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_cancel(response_type):
with patch.object(ResponseMock, "close", return_value=None) as mock_method:
resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
itr.cancel()
mock_method.assert_called_once()


def test_check_buffer():
@pytest.mark.parametrize(
"response_type,return_value",
[
(EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
(httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
],
)
def test_check_buffer(response_type, return_value):
with patch.object(
ResponseMock,
"_parse_responses",
return_value=bytes('[{"content": "hello"}, {', "utf-8"),
return_value=return_value,
):
resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
next(itr)


def test_next_html():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_next_html(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
) as mock_method:

resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()
Loading