Skip to content

Commit

Permalink
fix(v1): resolve issue handling protobuf responses in rest streaming (#…
Browse files Browse the repository at this point in the history
…609)

* fix: resolve issue handling protobuf responses in rest streaming

* raise ValueError if response_message_cls is not a subclass of proto.Message or google.protobuf.message.Message

* remove response_type from pytest.mark.parametrize

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* add test for ValueError in response_iterator._grab()

---------

Co-authored-by: Anthonios Partheniou <partheniou@google.com>
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 13, 2024
1 parent e962dee commit d386d2e
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 56 deletions.
25 changes: 21 additions & 4 deletions google/api_core/rest_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,31 @@

from collections import deque
import string
from typing import Deque
from typing import Deque, Union

import proto
import requests
import google.protobuf.message
from google.protobuf.json_format import Parse


class ResponseIterator:
"""Iterator over REST API responses.
Args:
response (requests.Response): An API response object.
response_message_cls (Callable[proto.Message]): A proto
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
class expected to be returned from an API.
Raises:
ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
"""

def __init__(self, response: requests.Response, response_message_cls):
def __init__(
self,
response: requests.Response,
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
):
self._response = response
self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
Expand Down Expand Up @@ -107,7 +117,14 @@ def __next__(self):

def _grab(self):
# Add extra quotes to make json.loads happy.
return self._response_message_cls.from_json(self._ready_objs.popleft())
if issubclass(self._response_message_cls, proto.Message):
return self._response_message_cls.from_json(self._ready_objs.popleft())
elif issubclass(self._response_message_cls, google.protobuf.message.Message):
return Parse(self._ready_objs.popleft(), self._response_message_cls())
else:
raise ValueError(
"Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
)

def __iter__(self):
return self
227 changes: 175 additions & 52 deletions tests/unit/test_rest_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
import requests

from google.api_core import rest_streaming
from google.api import http_pb2
from google.api import httpbody_pb2
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from google.protobuf.json_format import MessageToJson


__protobuf__ = proto.module(package=__name__)
Expand Down Expand Up @@ -98,7 +101,10 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes:
# json.dumps returns a string surrounded with quotes that need to be stripped
# in order to be an actual JSON.
json_responses = [
self._response_message_cls.to_json(r).strip('"') for r in responses
self._response_message_cls.to_json(r).strip('"')
if issubclass(self._response_message_cls, proto.Message)
else MessageToJson(r).strip('"')
for r in responses
]
logging.info(f"Sending JSON stream: {json_responses}")
ret_val = "[{}]".format(",".join(json_responses))
Expand All @@ -114,103 +120,220 @@ def iter_content(self, *args, **kwargs):
)


@pytest.mark.parametrize("random_split", [False])
def test_next_simple(random_split):
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[(False, True), (False, False)],
)
def test_next_simple(random_split, resp_message_is_proto_plus):
if resp_message_is_proto_plus:
response_type = EchoResponse
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
else:
response_type = httpbody_pb2.HttpBody
responses = [
httpbody_pb2.HttpBody(content_type="hello world"),
httpbody_pb2.HttpBody(content_type="yes"),
]

resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=EchoResponse
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_nested(random_split):
responses = [
Song(title="some song", composer=Composer(given_name="some name")),
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[
(True, True),
(False, True),
(True, False),
(False, False),
],
)
def test_next_nested(random_split, resp_message_is_proto_plus):
if resp_message_is_proto_plus:
response_type = Song
responses = [
Song(title="some song", composer=Composer(given_name="some name")),
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
]
else:
# Although `http_pb2.HttpRule`` is used in the response, any response message
# can be used which meets this criteria for the test of having a nested field.
response_type = http_pb2.HttpRule
responses = [
http_pb2.HttpRule(
selector="some selector",
custom=http_pb2.CustomHttpPattern(kind="some kind"),
),
http_pb2.HttpRule(
selector="another selector",
custom=http_pb2.CustomHttpPattern(path="some path"),
),
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_stress(random_split):
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[
(True, True),
(False, True),
(True, False),
(False, False),
],
)
def test_next_stress(random_split, resp_message_is_proto_plus):
n = 50
responses = [
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
for i in range(n)
]
if resp_message_is_proto_plus:
response_type = Song
responses = [
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
for i in range(n)
]
else:
response_type = http_pb2.HttpRule
responses = [
http_pb2.HttpRule(
selector="selector_%d" % i,
custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
)
for i in range(n)
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_escaped_characters_in_string(random_split):
composer_with_relateds = Composer()
relateds = ["Artist A", "Artist B"]
composer_with_relateds.relateds = relateds

responses = [
Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
Song(
title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
),
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[
(True, True),
(False, True),
(True, False),
(False, False),
],
)
def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus):
if resp_message_is_proto_plus:
response_type = Song
composer_with_relateds = Composer()
relateds = ["Artist A", "Artist B"]
composer_with_relateds.relateds = relateds

responses = [
Song(
title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
),
Song(
title='{"this is weird": "totally"}',
composer=Composer(given_name="\\{}\\"),
),
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
]
else:
response_type = http_pb2.Http
responses = [
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='ti"tle\nfoo\tbar{}',
custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
)
]
),
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='{"this is weird": "totally"}',
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
)
]
),
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='\\{"key": ["value",]}\\',
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
)
]
),
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


def test_next_not_array():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_next_not_array(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
) as mock_method:

resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()


def test_cancel():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_cancel(response_type):
with patch.object(ResponseMock, "close", return_value=None) as mock_method:
resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
itr.cancel()
mock_method.assert_called_once()


def test_check_buffer():
@pytest.mark.parametrize(
"response_type,return_value",
[
(EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
(httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
],
)
def test_check_buffer(response_type, return_value):
with patch.object(
ResponseMock,
"_parse_responses",
return_value=bytes('[{"content": "hello"}, {', "utf-8"),
return_value=return_value,
):
resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
next(itr)


def test_next_html():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_next_html(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
) as mock_method:

resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()


def test_invalid_response_class():
class SomeClass:
pass

resp = ResponseMock(responses=[], response_cls=SomeClass)
response_iterator = rest_streaming.ResponseIterator(resp, SomeClass)
with pytest.raises(
ValueError,
match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
):
response_iterator._grab()

0 comments on commit d386d2e

Please sign in to comment.