diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 8adc535e..5980f825 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps"] + option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps", "_with_auth_aio"] python: - "3.7" - "3.8" diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py new file mode 100644 index 00000000..3bc87a96 --- /dev/null +++ b/google/api_core/_rest_streaming_base.py @@ -0,0 +1,118 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for server-side streaming in REST.""" + +from collections import deque +import string +from typing import Deque, Union +import types + +import proto +import google.protobuf.message +from google.protobuf.json_format import Parse + + +class BaseResponseIterator: + """Base Iterator over REST API responses. This class should not be used directly. + + Args: + response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response + class expected to be returned from an API. + + Raises: + ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + """ + + def __init__( + self, + response_message_cls: Union[proto.Message, google.protobuf.message.Message], + ): + self._response_message_cls = response_message_cls + # Contains a list of JSON responses ready to be sent to user. + self._ready_objs: Deque[str] = deque() + # Current JSON response being built. + self._obj = "" + # Keeps track of the nesting level within a JSON object. + self._level = 0 + # Keeps track whether HTTP response is currently sending values + # inside of a string value. + self._in_string = False + # Whether an escape symbol "\" was encountered. + self._escape_next = False + + self._grab = types.MethodType(self._create_grab(), self) + + def _process_chunk(self, chunk: str): + if self._level == 0: + if chunk[0] != "[": + raise ValueError( + "Can only parse array of JSON objects, instead got %s" % chunk + ) + for char in chunk: + if char == "{": + if self._level == 1: + # Level 1 corresponds to the outermost JSON object + # (i.e. the one we care about). + self._obj = "" + if not self._in_string: + self._level += 1 + self._obj += char + elif char == "}": + self._obj += char + if not self._in_string: + self._level -= 1 + if not self._in_string and self._level == 1: + self._ready_objs.append(self._obj) + elif char == '"': + # Helps to deal with an escaped quotes inside of a string. + if not self._escape_next: + self._in_string = not self._in_string + self._obj += char + elif char in string.whitespace: + if self._in_string: + self._obj += char + elif char == "[": + if self._level == 0: + self._level += 1 + else: + self._obj += char + elif char == "]": + if self._level == 1: + self._level -= 1 + else: + self._obj += char + else: + self._obj += char + self._escape_next = not self._escape_next if char == "\\" else False + + def _create_grab(self): + if issubclass(self._response_message_cls, proto.Message): + + def grab(this): + return this._response_message_cls.from_json( + this._ready_objs.popleft(), ignore_unknown_fields=True + ) + + return grab + elif issubclass(self._response_message_cls, google.protobuf.message.Message): + + def grab(this): + return Parse(this._ready_objs.popleft(), this._response_message_cls()) + + return grab + else: + raise ValueError( + "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." + ) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 88bcb31b..84aa270c 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -14,17 +14,15 @@ """Helpers for server-side streaming in REST.""" -from collections import deque -import string -from typing import Deque, Union +from typing import Union import proto import requests import google.protobuf.message -from google.protobuf.json_format import Parse +from google.api_core._rest_streaming_base import BaseResponseIterator -class ResponseIterator: +class ResponseIterator(BaseResponseIterator): """Iterator over REST API responses. Args: @@ -33,7 +31,8 @@ class ResponseIterator: class expected to be returned from an API. Raises: - ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + ValueError: + - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. """ def __init__( @@ -42,68 +41,16 @@ def __init__( response_message_cls: Union[proto.Message, google.protobuf.message.Message], ): self._response = response - self._response_message_cls = response_message_cls # Inner iterator over HTTP response's content. self._response_itr = self._response.iter_content(decode_unicode=True) - # Contains a list of JSON responses ready to be sent to user. - self._ready_objs: Deque[str] = deque() - # Current JSON response being built. - self._obj = "" - # Keeps track of the nesting level within a JSON object. - self._level = 0 - # Keeps track whether HTTP response is currently sending values - # inside of a string value. - self._in_string = False - # Whether an escape symbol "\" was encountered. - self._escape_next = False + super(ResponseIterator, self).__init__( + response_message_cls=response_message_cls + ) def cancel(self): """Cancel existing streaming operation.""" self._response.close() - def _process_chunk(self, chunk: str): - if self._level == 0: - if chunk[0] != "[": - raise ValueError( - "Can only parse array of JSON objects, instead got %s" % chunk - ) - for char in chunk: - if char == "{": - if self._level == 1: - # Level 1 corresponds to the outermost JSON object - # (i.e. the one we care about). - self._obj = "" - if not self._in_string: - self._level += 1 - self._obj += char - elif char == "}": - self._obj += char - if not self._in_string: - self._level -= 1 - if not self._in_string and self._level == 1: - self._ready_objs.append(self._obj) - elif char == '"': - # Helps to deal with an escaped quotes inside of a string. - if not self._escape_next: - self._in_string = not self._in_string - self._obj += char - elif char in string.whitespace: - if self._in_string: - self._obj += char - elif char == "[": - if self._level == 0: - self._level += 1 - else: - self._obj += char - elif char == "]": - if self._level == 1: - self._level -= 1 - else: - self._obj += char - else: - self._obj += char - self._escape_next = not self._escape_next if char == "\\" else False - def __next__(self): while not self._ready_objs: try: @@ -115,18 +62,5 @@ def __next__(self): raise e return self._grab() - def _grab(self): - # Add extra quotes to make json.loads happy. - if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json( - self._ready_objs.popleft(), ignore_unknown_fields=True - ) - elif issubclass(self._response_message_cls, google.protobuf.message.Message): - return Parse(self._ready_objs.popleft(), self._response_message_cls()) - else: - raise ValueError( - "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." - ) - def __iter__(self): return self diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py new file mode 100644 index 00000000..d1f996f6 --- /dev/null +++ b/google/api_core/rest_streaming_async.py @@ -0,0 +1,83 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for asynchronous server-side streaming in REST.""" + +from typing import Union + +import proto + +try: + import google.auth.aio.transport +except ImportError as e: # pragma: NO COVER + raise ImportError( + "google-auth>=2.35.0 is required to use asynchronous rest streaming." + ) from e + +import google.protobuf.message +from google.api_core._rest_streaming_base import BaseResponseIterator + + +class AsyncResponseIterator(BaseResponseIterator): + """Asynchronous Iterator over REST API responses. + + Args: + response (google.auth.aio.transport.Response): An API response object. + response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response + class expected to be returned from an API. + + Raises: + ValueError: + - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + """ + + def __init__( + self, + response: google.auth.aio.transport.Response, + response_message_cls: Union[proto.Message, google.protobuf.message.Message], + ): + self._response = response + self._chunk_size = 1024 + self._response_itr = self._response.content().__aiter__() + super(AsyncResponseIterator, self).__init__( + response_message_cls=response_message_cls + ) + + async def __aenter__(self): + return self + + async def cancel(self): + """Cancel existing streaming operation.""" + await self._response.close() + + async def __anext__(self): + while not self._ready_objs: + try: + chunk = await self._response_itr.__anext__() + chunk = chunk.decode("utf-8") + self._process_chunk(chunk) + except StopAsyncIteration as e: + if self._level > 0: + raise ValueError("i Unfinished stream: %s" % self._obj) + raise e + except ValueError as e: + raise e + return self._grab() + + def __aiter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + """Cancel existing async streaming operation.""" + await self._response.close() diff --git a/noxfile.py b/noxfile.py index a15795ea..3fc4a722 100644 --- a/noxfile.py +++ b/noxfile.py @@ -38,6 +38,7 @@ "unit", "unit_grpc_gcp", "unit_wo_grpc", + "unit_with_auth_aio", "cover", "pytype", "mypy", @@ -109,7 +110,7 @@ def install_prerelease_dependencies(session, constraints_path): session.install(*other_deps) -def default(session, install_grpc=True, prerelease=False): +def default(session, install_grpc=True, prerelease=False, install_auth_aio=False): """Default unit test session. This is intended to be run **without** an interpreter set, so @@ -144,6 +145,11 @@ def default(session, install_grpc=True, prerelease=False): f"{constraints_dir}/constraints-{session.python}.txt", ) + if install_auth_aio: + session.install( + "google-auth @ git+https://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad" + ) + # Print out package versions of dependencies session.run( "python", "-c", "import google.protobuf; print(google.protobuf.__version__)" @@ -229,6 +235,12 @@ def unit_wo_grpc(session): default(session, install_grpc=False) +@nox.session(python=PYTHON_VERSIONS) +def unit_with_auth_aio(session): + """Run the unit test suite with google.auth.aio installed""" + default(session, install_auth_aio=True) + + @nox.session(python=DEFAULT_PYTHON_VERSION) def lint_setup_py(session): """Verify that setup.py is valid (including RST check).""" diff --git a/tests/asyncio/test_rest_streaming_async.py b/tests/asyncio/test_rest_streaming_async.py new file mode 100644 index 00000000..35820de6 --- /dev/null +++ b/tests/asyncio/test_rest_streaming_async.py @@ -0,0 +1,378 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: set random.seed explicitly in each test function. +# See related issue: https://github.com/googleapis/python-api-core/issues/689. + +import pytest # noqa: I202 +import mock + +import datetime +import logging +import random +import time +from typing import List, AsyncIterator + +import proto + +try: + from google.auth.aio.transport import Response + + AUTH_AIO_INSTALLED = True +except ImportError: + AUTH_AIO_INSTALLED = False + +if not AUTH_AIO_INSTALLED: # pragma: NO COVER + pytest.skip( + "google-auth>=2.35.0 is required to use asynchronous rest streaming.", + allow_module_level=True, + ) + +from google.api_core import rest_streaming_async +from google.api import http_pb2 +from google.api import httpbody_pb2 + + +from ..helpers import Composer, Song, EchoResponse, parse_responses + + +__protobuf__ = proto.module(package=__name__) +SEED = int(time.time()) +logging.info(f"Starting async rest streaming tests with random seed: {SEED}") +random.seed(SEED) + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +class ResponseMock(Response): + class _ResponseItr(AsyncIterator[bytes]): + def __init__(self, _response_bytes: bytes, random_split=False): + self._responses_bytes = _response_bytes + self._idx = 0 + self._random_split = random_split + + def __aiter__(self): + return self + + async def __anext__(self): + if self._idx >= len(self._responses_bytes): + raise StopAsyncIteration + if self._random_split: + n = random.randint(1, len(self._responses_bytes[self._idx :])) + else: + n = 1 + x = self._responses_bytes[self._idx : self._idx + n] + self._idx += n + return x + + def __init__( + self, + responses: List[proto.Message], + response_cls, + random_split=False, + ): + self._responses = responses + self._random_split = random_split + self._response_message_cls = response_cls + + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) + + @property + async def headers(self): + raise NotImplementedError() + + @property + async def status_code(self): + raise NotImplementedError() + + async def close(self): + raise NotImplementedError() + + async def content(self, chunk_size=None): + itr = self._ResponseItr( + self._parse_responses(), random_split=self._random_split + ) + async for chunk in itr: + yield chunk + + async def read(self): + raise NotImplementedError() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [(False, True), (False, False)], +) +async def test_next_simple(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = EchoResponse + responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] + else: + response_type = httpbody_pb2.HttpBody + responses = [ + httpbody_pb2.HttpBody(content_type="hello world"), + httpbody_pb2.HttpBody(content_type="yes"), + ] + + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +async def test_next_nested(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="some song", composer=Composer(given_name="some name")), + Song(title="another song", date_added=datetime.datetime(2021, 12, 17)), + ] + else: + # Although `http_pb2.HttpRule`` is used in the response, any response message + # can be used which meets this criteria for the test of having a nested field. + response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="some selector", + custom=http_pb2.CustomHttpPattern(kind="some kind"), + ), + http_pb2.HttpRule( + selector="another selector", + custom=http_pb2.CustomHttpPattern(path="some path"), + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + assert idx == len(responses) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +async def test_next_stress(random_split, resp_message_is_proto_plus): + n = 50 + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) + for i in range(n) + ] + else: + response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="selector_%d" % i, + custom=http_pb2.CustomHttpPattern(path="path_%d" % i), + ) + for i in range(n) + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + assert idx == n + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +async def test_next_escaped_characters_in_string( + random_split, resp_message_is_proto_plus +): + if resp_message_is_proto_plus: + response_type = Song + composer_with_relateds = Composer() + relateds = ["Artist A", "Artist B"] + composer_with_relateds.relateds = relateds + + responses = [ + Song( + title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n") + ), + Song( + title='{"this is weird": "totally"}', + composer=Composer(given_name="\\{}\\"), + ), + Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds), + ] + else: + response_type = http_pb2.Http + responses = [ + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='ti"tle\nfoo\tbar{}', + custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='{"this is weird": "totally"}', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='\\{"key": ["value",]}\\', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + idx = 0 + async for response in itr: + assert response == responses[idx] + idx += 1 + assert idx == len(responses) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +async def test_next_not_array(response_type): + + data = '{"hello": 0}' + with mock.patch.object( + ResponseMock, "content", return_value=mock_async_gen(data) + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + with pytest.raises(ValueError): + await itr.__anext__() + mock_method.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +async def test_cancel(response_type): + with mock.patch.object( + ResponseMock, "close", new_callable=mock.AsyncMock + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + await itr.cancel() + mock_method.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +async def test_iterator_as_context_manager(response_type): + with mock.patch.object( + ResponseMock, "close", new_callable=mock.AsyncMock + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + async with rest_streaming_async.AsyncResponseIterator(resp, response_type): + pass + mock_method.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "response_type,return_value", + [ + (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")), + (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")), + ], +) +async def test_check_buffer(response_type, return_value): + with mock.patch.object( + ResponseMock, + "_parse_responses", + return_value=return_value, + ): + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + with pytest.raises(ValueError): + await itr.__anext__() + await itr.__anext__() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +async def test_next_html(response_type): + + data = "" + with mock.patch.object( + ResponseMock, "content", return_value=mock_async_gen(data) + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) + with pytest.raises(ValueError): + await itr.__anext__() + mock_method.assert_called_once() + + +@pytest.mark.asyncio +async def test_invalid_response_class(): + class SomeClass: + pass + + resp = ResponseMock(responses=[], response_cls=SomeClass) + with pytest.raises( + ValueError, + match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", + ): + rest_streaming_async.AsyncResponseIterator(resp, SomeClass) diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 00000000..3429d511 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,71 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for tests""" + +import logging +from typing import List + +import proto + +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf.json_format import MessageToJson + + +class Genre(proto.Enum): + GENRE_UNSPECIFIED = 0 + CLASSICAL = 1 + JAZZ = 2 + ROCK = 3 + + +class Composer(proto.Message): + given_name = proto.Field(proto.STRING, number=1) + family_name = proto.Field(proto.STRING, number=2) + relateds = proto.RepeatedField(proto.STRING, number=3) + indices = proto.MapField(proto.STRING, proto.STRING, number=4) + + +class Song(proto.Message): + composer = proto.Field(Composer, number=1) + title = proto.Field(proto.STRING, number=2) + lyrics = proto.Field(proto.STRING, number=3) + year = proto.Field(proto.INT32, number=4) + genre = proto.Field(Genre, number=5) + is_five_mins_longer = proto.Field(proto.BOOL, number=6) + score = proto.Field(proto.DOUBLE, number=7) + likes = proto.Field(proto.INT64, number=8) + duration = proto.Field(duration_pb2.Duration, number=9) + date_added = proto.Field(timestamp_pb2.Timestamp, number=10) + + +class EchoResponse(proto.Message): + content = proto.Field(proto.STRING, number=1) + + +def parse_responses(response_message_cls, all_responses: List[proto.Message]) -> bytes: + # json.dumps returns a string surrounded with quotes that need to be stripped + # in order to be an actual JSON. + json_responses = [ + ( + response_message_cls.to_json(response).strip('"') + if issubclass(response_message_cls, proto.Message) + else MessageToJson(response).strip('"') + ) + for response in all_responses + ] + logging.info(f"Sending JSON stream: {json_responses}") + ret_val = "[{}]".format(",".join(json_responses)) + return bytes(ret_val, "utf-8") diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 0f2b3b32..0f998dfe 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -26,48 +26,16 @@ from google.api_core import rest_streaming from google.api import http_pb2 from google.api import httpbody_pb2 -from google.protobuf import duration_pb2 -from google.protobuf import timestamp_pb2 -from google.protobuf.json_format import MessageToJson + +from ..helpers import Composer, Song, EchoResponse, parse_responses __protobuf__ = proto.module(package=__name__) SEED = int(time.time()) -logging.info(f"Starting rest streaming tests with random seed: {SEED}") +logging.info(f"Starting sync rest streaming tests with random seed: {SEED}") random.seed(SEED) -class Genre(proto.Enum): - GENRE_UNSPECIFIED = 0 - CLASSICAL = 1 - JAZZ = 2 - ROCK = 3 - - -class Composer(proto.Message): - given_name = proto.Field(proto.STRING, number=1) - family_name = proto.Field(proto.STRING, number=2) - relateds = proto.RepeatedField(proto.STRING, number=3) - indices = proto.MapField(proto.STRING, proto.STRING, number=4) - - -class Song(proto.Message): - composer = proto.Field(Composer, number=1) - title = proto.Field(proto.STRING, number=2) - lyrics = proto.Field(proto.STRING, number=3) - year = proto.Field(proto.INT32, number=4) - genre = proto.Field(Genre, number=5) - is_five_mins_longer = proto.Field(proto.BOOL, number=6) - score = proto.Field(proto.DOUBLE, number=7) - likes = proto.Field(proto.INT64, number=8) - duration = proto.Field(duration_pb2.Duration, number=9) - date_added = proto.Field(timestamp_pb2.Timestamp, number=10) - - -class EchoResponse(proto.Message): - content = proto.Field(proto.STRING, number=1) - - class ResponseMock(requests.Response): class _ResponseItr: def __init__(self, _response_bytes: bytes, random_split=False): @@ -97,27 +65,15 @@ def __init__( self._random_split = random_split self._response_message_cls = response_cls - def _parse_responses(self, responses: List[proto.Message]) -> bytes: - # json.dumps returns a string surrounded with quotes that need to be stripped - # in order to be an actual JSON. - json_responses = [ - ( - self._response_message_cls.to_json(r).strip('"') - if issubclass(self._response_message_cls, proto.Message) - else MessageToJson(r).strip('"') - ) - for r in responses - ] - logging.info(f"Sending JSON stream: {json_responses}") - ret_val = "[{}]".format(",".join(json_responses)) - return bytes(ret_val, "utf-8") + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) def close(self): raise NotImplementedError() def iter_content(self, *args, **kwargs): return self._ResponseItr( - self._parse_responses(self._responses), + self._parse_responses(), random_split=self._random_split, ) @@ -333,9 +289,8 @@ class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) - response_iterator = rest_streaming.ResponseIterator(resp, SomeClass) with pytest.raises( ValueError, match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", ): - response_iterator._grab() + rest_streaming.ResponseIterator(resp, SomeClass)