From 4df21201707bc61d3eeddc19570c6996a3dc371a Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:16:11 +0000 Subject: [PATCH 01/24] duplicating file to base --- google/api_core/{rest_streaming.py => _rest_streaming_base.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename google/api_core/{rest_streaming.py => _rest_streaming_base.py} (100%) diff --git a/google/api_core/rest_streaming.py b/google/api_core/_rest_streaming_base.py similarity index 100% rename from google/api_core/rest_streaming.py rename to google/api_core/_rest_streaming_base.py From 26f52a57334a25a7b6e40ce1a9b4f5d39d2460f1 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:20:31 +0000 Subject: [PATCH 02/24] restore original file --- google/api_core/rest_streaming.py | 132 ++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 google/api_core/rest_streaming.py diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py new file mode 100644 index 00000000..88bcb31b --- /dev/null +++ b/google/api_core/rest_streaming.py @@ -0,0 +1,132 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for server-side streaming in REST.""" + +from collections import deque +import string +from typing import Deque, Union + +import proto +import requests +import google.protobuf.message +from google.protobuf.json_format import Parse + + +class ResponseIterator: + """Iterator over REST API responses. + + Args: + response (requests.Response): An API response object. + response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response + class expected to be returned from an API. + + Raises: + ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + """ + + def __init__( + self, + response: requests.Response, + response_message_cls: Union[proto.Message, google.protobuf.message.Message], + ): + self._response = response + self._response_message_cls = response_message_cls + # Inner iterator over HTTP response's content. + self._response_itr = self._response.iter_content(decode_unicode=True) + # Contains a list of JSON responses ready to be sent to user. + self._ready_objs: Deque[str] = deque() + # Current JSON response being built. + self._obj = "" + # Keeps track of the nesting level within a JSON object. + self._level = 0 + # Keeps track whether HTTP response is currently sending values + # inside of a string value. + self._in_string = False + # Whether an escape symbol "\" was encountered. + self._escape_next = False + + def cancel(self): + """Cancel existing streaming operation.""" + self._response.close() + + def _process_chunk(self, chunk: str): + if self._level == 0: + if chunk[0] != "[": + raise ValueError( + "Can only parse array of JSON objects, instead got %s" % chunk + ) + for char in chunk: + if char == "{": + if self._level == 1: + # Level 1 corresponds to the outermost JSON object + # (i.e. the one we care about). + self._obj = "" + if not self._in_string: + self._level += 1 + self._obj += char + elif char == "}": + self._obj += char + if not self._in_string: + self._level -= 1 + if not self._in_string and self._level == 1: + self._ready_objs.append(self._obj) + elif char == '"': + # Helps to deal with an escaped quotes inside of a string. + if not self._escape_next: + self._in_string = not self._in_string + self._obj += char + elif char in string.whitespace: + if self._in_string: + self._obj += char + elif char == "[": + if self._level == 0: + self._level += 1 + else: + self._obj += char + elif char == "]": + if self._level == 1: + self._level -= 1 + else: + self._obj += char + else: + self._obj += char + self._escape_next = not self._escape_next if char == "\\" else False + + def __next__(self): + while not self._ready_objs: + try: + chunk = next(self._response_itr) + self._process_chunk(chunk) + except StopIteration as e: + if self._level > 0: + raise ValueError("Unfinished stream: %s" % self._obj) + raise e + return self._grab() + + def _grab(self): + # Add extra quotes to make json.loads happy. + if issubclass(self._response_message_cls, proto.Message): + return self._response_message_cls.from_json( + self._ready_objs.popleft(), ignore_unknown_fields=True + ) + elif issubclass(self._response_message_cls, google.protobuf.message.Message): + return Parse(self._ready_objs.popleft(), self._response_message_cls()) + else: + raise ValueError( + "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." + ) + + def __iter__(self): + return self From 204920a3f2da427329f22f0a77dea03dd05b8610 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:22:10 +0000 Subject: [PATCH 03/24] duplicate file to async --- google/api_core/{rest_streaming.py => rest_streaming_async.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename google/api_core/{rest_streaming.py => rest_streaming_async.py} (100%) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming_async.py similarity index 100% rename from google/api_core/rest_streaming.py rename to google/api_core/rest_streaming_async.py From 2014ef6b63f4e1ea9d3fdd04f40306fbdc76ae17 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:22:35 +0000 Subject: [PATCH 04/24] restore original file --- google/api_core/rest_streaming.py | 132 ++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 google/api_core/rest_streaming.py diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py new file mode 100644 index 00000000..88bcb31b --- /dev/null +++ b/google/api_core/rest_streaming.py @@ -0,0 +1,132 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for server-side streaming in REST.""" + +from collections import deque +import string +from typing import Deque, Union + +import proto +import requests +import google.protobuf.message +from google.protobuf.json_format import Parse + + +class ResponseIterator: + """Iterator over REST API responses. + + Args: + response (requests.Response): An API response object. + response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response + class expected to be returned from an API. + + Raises: + ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + """ + + def __init__( + self, + response: requests.Response, + response_message_cls: Union[proto.Message, google.protobuf.message.Message], + ): + self._response = response + self._response_message_cls = response_message_cls + # Inner iterator over HTTP response's content. + self._response_itr = self._response.iter_content(decode_unicode=True) + # Contains a list of JSON responses ready to be sent to user. + self._ready_objs: Deque[str] = deque() + # Current JSON response being built. + self._obj = "" + # Keeps track of the nesting level within a JSON object. + self._level = 0 + # Keeps track whether HTTP response is currently sending values + # inside of a string value. + self._in_string = False + # Whether an escape symbol "\" was encountered. + self._escape_next = False + + def cancel(self): + """Cancel existing streaming operation.""" + self._response.close() + + def _process_chunk(self, chunk: str): + if self._level == 0: + if chunk[0] != "[": + raise ValueError( + "Can only parse array of JSON objects, instead got %s" % chunk + ) + for char in chunk: + if char == "{": + if self._level == 1: + # Level 1 corresponds to the outermost JSON object + # (i.e. the one we care about). + self._obj = "" + if not self._in_string: + self._level += 1 + self._obj += char + elif char == "}": + self._obj += char + if not self._in_string: + self._level -= 1 + if not self._in_string and self._level == 1: + self._ready_objs.append(self._obj) + elif char == '"': + # Helps to deal with an escaped quotes inside of a string. + if not self._escape_next: + self._in_string = not self._in_string + self._obj += char + elif char in string.whitespace: + if self._in_string: + self._obj += char + elif char == "[": + if self._level == 0: + self._level += 1 + else: + self._obj += char + elif char == "]": + if self._level == 1: + self._level -= 1 + else: + self._obj += char + else: + self._obj += char + self._escape_next = not self._escape_next if char == "\\" else False + + def __next__(self): + while not self._ready_objs: + try: + chunk = next(self._response_itr) + self._process_chunk(chunk) + except StopIteration as e: + if self._level > 0: + raise ValueError("Unfinished stream: %s" % self._obj) + raise e + return self._grab() + + def _grab(self): + # Add extra quotes to make json.loads happy. + if issubclass(self._response_message_cls, proto.Message): + return self._response_message_cls.from_json( + self._ready_objs.popleft(), ignore_unknown_fields=True + ) + elif issubclass(self._response_message_cls, google.protobuf.message.Message): + return Parse(self._ready_objs.popleft(), self._response_message_cls()) + else: + raise ValueError( + "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." + ) + + def __iter__(self): + return self From 8014f475195ccea4aab09bf747bc8add1eaacc71 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:33:13 +0000 Subject: [PATCH 05/24] duplicate test file for async --- .../unit/{test_rest_streaming.py => test_rest_streaming_async.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unit/{test_rest_streaming.py => test_rest_streaming_async.py} (100%) diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming_async.py similarity index 100% rename from tests/unit/test_rest_streaming.py rename to tests/unit/test_rest_streaming_async.py From e84f03af3776d06a48e21e0b2edf726fd53faacd Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:33:46 +0000 Subject: [PATCH 06/24] restore test file --- tests/unit/test_rest_streaming.py | 341 ++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) create mode 100644 tests/unit/test_rest_streaming.py diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py new file mode 100644 index 00000000..0f2b3b32 --- /dev/null +++ b/tests/unit/test_rest_streaming.py @@ -0,0 +1,341 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import logging +import random +import time +from typing import List +from unittest.mock import patch + +import proto +import pytest +import requests + +from google.api_core import rest_streaming +from google.api import http_pb2 +from google.api import httpbody_pb2 +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf.json_format import MessageToJson + + +__protobuf__ = proto.module(package=__name__) +SEED = int(time.time()) +logging.info(f"Starting rest streaming tests with random seed: {SEED}") +random.seed(SEED) + + +class Genre(proto.Enum): + GENRE_UNSPECIFIED = 0 + CLASSICAL = 1 + JAZZ = 2 + ROCK = 3 + + +class Composer(proto.Message): + given_name = proto.Field(proto.STRING, number=1) + family_name = proto.Field(proto.STRING, number=2) + relateds = proto.RepeatedField(proto.STRING, number=3) + indices = proto.MapField(proto.STRING, proto.STRING, number=4) + + +class Song(proto.Message): + composer = proto.Field(Composer, number=1) + title = proto.Field(proto.STRING, number=2) + lyrics = proto.Field(proto.STRING, number=3) + year = proto.Field(proto.INT32, number=4) + genre = proto.Field(Genre, number=5) + is_five_mins_longer = proto.Field(proto.BOOL, number=6) + score = proto.Field(proto.DOUBLE, number=7) + likes = proto.Field(proto.INT64, number=8) + duration = proto.Field(duration_pb2.Duration, number=9) + date_added = proto.Field(timestamp_pb2.Timestamp, number=10) + + +class EchoResponse(proto.Message): + content = proto.Field(proto.STRING, number=1) + + +class ResponseMock(requests.Response): + class _ResponseItr: + def __init__(self, _response_bytes: bytes, random_split=False): + self._responses_bytes = _response_bytes + self._i = 0 + self._random_split = random_split + + def __next__(self): + if self._i == len(self._responses_bytes): + raise StopIteration + if self._random_split: + n = random.randint(1, len(self._responses_bytes[self._i :])) + else: + n = 1 + x = self._responses_bytes[self._i : self._i + n] + self._i += n + return x.decode("utf-8") + + def __init__( + self, + responses: List[proto.Message], + response_cls, + random_split=False, + ): + super().__init__() + self._responses = responses + self._random_split = random_split + self._response_message_cls = response_cls + + def _parse_responses(self, responses: List[proto.Message]) -> bytes: + # json.dumps returns a string surrounded with quotes that need to be stripped + # in order to be an actual JSON. + json_responses = [ + ( + self._response_message_cls.to_json(r).strip('"') + if issubclass(self._response_message_cls, proto.Message) + else MessageToJson(r).strip('"') + ) + for r in responses + ] + logging.info(f"Sending JSON stream: {json_responses}") + ret_val = "[{}]".format(",".join(json_responses)) + return bytes(ret_val, "utf-8") + + def close(self): + raise NotImplementedError() + + def iter_content(self, *args, **kwargs): + return self._ResponseItr( + self._parse_responses(self._responses), + random_split=self._random_split, + ) + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [(False, True), (False, False)], +) +def test_next_simple(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = EchoResponse + responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] + else: + response_type = httpbody_pb2.HttpBody + responses = [ + httpbody_pb2.HttpBody(content_type="hello world"), + httpbody_pb2.HttpBody(content_type="yes"), + ] + + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +def test_next_nested(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="some song", composer=Composer(given_name="some name")), + Song(title="another song", date_added=datetime.datetime(2021, 12, 17)), + ] + else: + # Although `http_pb2.HttpRule`` is used in the response, any response message + # can be used which meets this criteria for the test of having a nested field. + response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="some selector", + custom=http_pb2.CustomHttpPattern(kind="some kind"), + ), + http_pb2.HttpRule( + selector="another selector", + custom=http_pb2.CustomHttpPattern(path="some path"), + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +def test_next_stress(random_split, resp_message_is_proto_plus): + n = 50 + if resp_message_is_proto_plus: + response_type = Song + responses = [ + Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) + for i in range(n) + ] + else: + response_type = http_pb2.HttpRule + responses = [ + http_pb2.HttpRule( + selector="selector_%d" % i, + custom=http_pb2.CustomHttpPattern(path="path_%d" % i), + ) + for i in range(n) + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize( + "random_split,resp_message_is_proto_plus", + [ + (True, True), + (False, True), + (True, False), + (False, False), + ], +) +def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus): + if resp_message_is_proto_plus: + response_type = Song + composer_with_relateds = Composer() + relateds = ["Artist A", "Artist B"] + composer_with_relateds.relateds = relateds + + responses = [ + Song( + title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n") + ), + Song( + title='{"this is weird": "totally"}', + composer=Composer(given_name="\\{}\\"), + ), + Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds), + ] + else: + response_type = http_pb2.Http + responses = [ + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='ti"tle\nfoo\tbar{}', + custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='{"this is weird": "totally"}', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + http_pb2.Http( + rules=[ + http_pb2.HttpRule( + selector='\\{"key": ["value",]}\\', + custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), + ) + ] + ), + ] + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=response_type + ) + itr = rest_streaming.ResponseIterator(resp, response_type) + assert list(itr) == responses + + +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +def test_next_not_array(response_type): + with patch.object( + ResponseMock, "iter_content", return_value=iter('{"hello": 0}') + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) + with pytest.raises(ValueError): + next(itr) + mock_method.assert_called_once() + + +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +def test_cancel(response_type): + with patch.object(ResponseMock, "close", return_value=None) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) + itr.cancel() + mock_method.assert_called_once() + + +@pytest.mark.parametrize( + "response_type,return_value", + [ + (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")), + (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")), + ], +) +def test_check_buffer(response_type, return_value): + with patch.object( + ResponseMock, + "_parse_responses", + return_value=return_value, + ): + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) + with pytest.raises(ValueError): + next(itr) + next(itr) + + +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +def test_next_html(response_type): + with patch.object( + ResponseMock, "iter_content", return_value=iter("") + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) + with pytest.raises(ValueError): + next(itr) + mock_method.assert_called_once() + + +def test_invalid_response_class(): + class SomeClass: + pass + + resp = ResponseMock(responses=[], response_cls=SomeClass) + response_iterator = rest_streaming.ResponseIterator(resp, SomeClass) + with pytest.raises( + ValueError, + match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", + ): + response_iterator._grab() From 0c2ee5cb595dbc6b527bf127e1f1026fbb6bcd54 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:38:35 +0000 Subject: [PATCH 07/24] feat: add support for asynchronous rest streaming --- google/api_core/_rest_streaming_base.py | 33 +--- google/api_core/rest_streaming.py | 79 +-------- google/api_core/rest_streaming_async.py | 122 ++++---------- tests/conftest.py | 70 ++++++++ tests/unit/test_rest_streaming.py | 76 ++------- tests/unit/test_rest_streaming_async.py | 206 +++++++++++++----------- 6 files changed, 238 insertions(+), 348 deletions(-) create mode 100644 tests/conftest.py diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py index 88bcb31b..03d7011e 100644 --- a/google/api_core/_rest_streaming_base.py +++ b/google/api_core/_rest_streaming_base.py @@ -1,4 +1,4 @@ -# Copyright 2021 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,13 +19,12 @@ from typing import Deque, Union import proto -import requests import google.protobuf.message from google.protobuf.json_format import Parse -class ResponseIterator: - """Iterator over REST API responses. +class BaseResponseIterator: + """Base Iterator over REST API responses. This class should not be used directly. Args: response (requests.Response): An API response object. @@ -38,13 +37,11 @@ class expected to be returned from an API. def __init__( self, - response: requests.Response, response_message_cls: Union[proto.Message, google.protobuf.message.Message], ): - self._response = response self._response_message_cls = response_message_cls # Inner iterator over HTTP response's content. - self._response_itr = self._response.iter_content(decode_unicode=True) + # self._response_itr = self._response.iter_content(decode_unicode=True) # Contains a list of JSON responses ready to be sent to user. self._ready_objs: Deque[str] = deque() # Current JSON response being built. @@ -57,10 +54,6 @@ def __init__( # Whether an escape symbol "\" was encountered. self._escape_next = False - def cancel(self): - """Cancel existing streaming operation.""" - self._response.close() - def _process_chunk(self, chunk: str): if self._level == 0: if chunk[0] != "[": @@ -104,29 +97,13 @@ def _process_chunk(self, chunk: str): self._obj += char self._escape_next = not self._escape_next if char == "\\" else False - def __next__(self): - while not self._ready_objs: - try: - chunk = next(self._response_itr) - self._process_chunk(chunk) - except StopIteration as e: - if self._level > 0: - raise ValueError("Unfinished stream: %s" % self._obj) - raise e - return self._grab() - def _grab(self): # Add extra quotes to make json.loads happy. if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json( - self._ready_objs.popleft(), ignore_unknown_fields=True - ) + return self._response_message_cls.from_json(self._ready_objs.popleft()) elif issubclass(self._response_message_cls, google.protobuf.message.Message): return Parse(self._ready_objs.popleft(), self._response_message_cls()) else: raise ValueError( "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." ) - - def __iter__(self): - return self diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 88bcb31b..98fdd85d 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -14,17 +14,15 @@ """Helpers for server-side streaming in REST.""" -from collections import deque -import string -from typing import Deque, Union +from typing import Union import proto import requests import google.protobuf.message -from google.protobuf.json_format import Parse +from google.api_core._rest_streaming_base import BaseResponseIterator -class ResponseIterator: +class ResponseIterator(BaseResponseIterator): """Iterator over REST API responses. Args: @@ -42,68 +40,16 @@ def __init__( response_message_cls: Union[proto.Message, google.protobuf.message.Message], ): self._response = response - self._response_message_cls = response_message_cls # Inner iterator over HTTP response's content. self._response_itr = self._response.iter_content(decode_unicode=True) - # Contains a list of JSON responses ready to be sent to user. - self._ready_objs: Deque[str] = deque() - # Current JSON response being built. - self._obj = "" - # Keeps track of the nesting level within a JSON object. - self._level = 0 - # Keeps track whether HTTP response is currently sending values - # inside of a string value. - self._in_string = False - # Whether an escape symbol "\" was encountered. - self._escape_next = False + super(ResponseIterator, self).__init__( + response_message_cls=response_message_cls + ) def cancel(self): """Cancel existing streaming operation.""" self._response.close() - def _process_chunk(self, chunk: str): - if self._level == 0: - if chunk[0] != "[": - raise ValueError( - "Can only parse array of JSON objects, instead got %s" % chunk - ) - for char in chunk: - if char == "{": - if self._level == 1: - # Level 1 corresponds to the outermost JSON object - # (i.e. the one we care about). - self._obj = "" - if not self._in_string: - self._level += 1 - self._obj += char - elif char == "}": - self._obj += char - if not self._in_string: - self._level -= 1 - if not self._in_string and self._level == 1: - self._ready_objs.append(self._obj) - elif char == '"': - # Helps to deal with an escaped quotes inside of a string. - if not self._escape_next: - self._in_string = not self._in_string - self._obj += char - elif char in string.whitespace: - if self._in_string: - self._obj += char - elif char == "[": - if self._level == 0: - self._level += 1 - else: - self._obj += char - elif char == "]": - if self._level == 1: - self._level -= 1 - else: - self._obj += char - else: - self._obj += char - self._escape_next = not self._escape_next if char == "\\" else False - def __next__(self): while not self._ready_objs: try: @@ -115,18 +61,5 @@ def __next__(self): raise e return self._grab() - def _grab(self): - # Add extra quotes to make json.loads happy. - if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json( - self._ready_objs.popleft(), ignore_unknown_fields=True - ) - elif issubclass(self._response_message_cls, google.protobuf.message.Message): - return Parse(self._ready_objs.popleft(), self._response_message_cls()) - else: - raise ValueError( - "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." - ) - def __iter__(self): return self diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index 88bcb31b..6e9c87f3 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -1,4 +1,4 @@ -# Copyright 2021 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,121 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Helpers for server-side streaming in REST.""" +"""Helpers for server-side asynchronous streaming in REST.""" -from collections import deque -import string -from typing import Deque, Union +from typing import Union import proto -import requests +import google.auth.aio.transport import google.protobuf.message -from google.protobuf.json_format import Parse +from google.api_core._rest_streaming_base import BaseResponseIterator -class ResponseIterator: - """Iterator over REST API responses. +class AsyncResponseIterator(BaseResponseIterator): + """Asynchronous Iterator over REST API responses. Args: - response (requests.Response): An API response object. + response (google.auth.aio.transport.Response): An API response object. response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response class expected to be returned from an API. Raises: - ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + ValueError: + - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + - If `response` is not an instance of `google.auth.aio.transport.aiohttp.Response`. """ def __init__( self, - response: requests.Response, + response: google.auth.aio.transport.Response, response_message_cls: Union[proto.Message, google.protobuf.message.Message], ): self._response = response - self._response_message_cls = response_message_cls - # Inner iterator over HTTP response's content. - self._response_itr = self._response.iter_content(decode_unicode=True) - # Contains a list of JSON responses ready to be sent to user. - self._ready_objs: Deque[str] = deque() - # Current JSON response being built. - self._obj = "" - # Keeps track of the nesting level within a JSON object. - self._level = 0 - # Keeps track whether HTTP response is currently sending values - # inside of a string value. - self._in_string = False - # Whether an escape symbol "\" was encountered. - self._escape_next = False + self._chunk_size = 1024 + self._response_itr = self._response.content().__aiter__() + super(AsyncResponseIterator, self).__init__( + response_message_cls=response_message_cls + ) - def cancel(self): + async def cancel(self): """Cancel existing streaming operation.""" - self._response.close() + await self._response.close() - def _process_chunk(self, chunk: str): - if self._level == 0: - if chunk[0] != "[": - raise ValueError( - "Can only parse array of JSON objects, instead got %s" % chunk - ) - for char in chunk: - if char == "{": - if self._level == 1: - # Level 1 corresponds to the outermost JSON object - # (i.e. the one we care about). - self._obj = "" - if not self._in_string: - self._level += 1 - self._obj += char - elif char == "}": - self._obj += char - if not self._in_string: - self._level -= 1 - if not self._in_string and self._level == 1: - self._ready_objs.append(self._obj) - elif char == '"': - # Helps to deal with an escaped quotes inside of a string. - if not self._escape_next: - self._in_string = not self._in_string - self._obj += char - elif char in string.whitespace: - if self._in_string: - self._obj += char - elif char == "[": - if self._level == 0: - self._level += 1 - else: - self._obj += char - elif char == "]": - if self._level == 1: - self._level -= 1 - else: - self._obj += char - else: - self._obj += char - self._escape_next = not self._escape_next if char == "\\" else False - - def __next__(self): + async def __anext__(self): while not self._ready_objs: - try: - chunk = next(self._response_itr) + try: + chunk = await self._response_itr.__anext__() + chunk = chunk.decode("utf-8") self._process_chunk(chunk) - except StopIteration as e: + except StopAsyncIteration as e: if self._level > 0: - raise ValueError("Unfinished stream: %s" % self._obj) + raise ValueError("i Unfinished stream: %s" % self._obj) + raise e + except ValueError as e: raise e return self._grab() - def _grab(self): - # Add extra quotes to make json.loads happy. - if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json( - self._ready_objs.popleft(), ignore_unknown_fields=True - ) - elif issubclass(self._response_message_cls, google.protobuf.message.Message): - return Parse(self._ready_objs.popleft(), self._response_message_cls()) - else: - raise ValueError( - "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." - ) - - def __iter__(self): + def __aiter__(self): return self + + async def __aexit__(self, exc_type, exc, tb): + """Cancel existing async streaming operation.""" + await self._response.close() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..e74cce31 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,70 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from typing import List + +import proto + +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf.json_format import MessageToJson + + +class Genre(proto.Enum): + GENRE_UNSPECIFIED = 0 + CLASSICAL = 1 + JAZZ = 2 + ROCK = 3 + + +class Composer(proto.Message): + given_name = proto.Field(proto.STRING, number=1) + family_name = proto.Field(proto.STRING, number=2) + relateds = proto.RepeatedField(proto.STRING, number=3) + indices = proto.MapField(proto.STRING, proto.STRING, number=4) + + +class Song(proto.Message): + composer = proto.Field(Composer, number=1) + title = proto.Field(proto.STRING, number=2) + lyrics = proto.Field(proto.STRING, number=3) + year = proto.Field(proto.INT32, number=4) + genre = proto.Field(Genre, number=5) + is_five_mins_longer = proto.Field(proto.BOOL, number=6) + score = proto.Field(proto.DOUBLE, number=7) + likes = proto.Field(proto.INT64, number=8) + duration = proto.Field(duration_pb2.Duration, number=9) + date_added = proto.Field(timestamp_pb2.Timestamp, number=10) + + +class EchoResponse(proto.Message): + content = proto.Field(proto.STRING, number=1) + + +def parse_responses(response_message_cls, responses: List[proto.Message]) -> bytes: + # json.dumps returns a string surrounded with quotes that need to be stripped + # in order to be an actual JSON. + json_responses = [ + ( + response_message_cls.to_json(r).strip('"') + if issubclass(response_message_cls, proto.Message) + else MessageToJson(r).strip('"') + ) + for r in responses + ] + logging.info(f"Sending JSON stream: {json_responses}") + ret_val = "[{}]".format(",".join(json_responses)) + return bytes(ret_val, "utf-8") diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 0f2b3b32..b215d02b 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -23,12 +23,11 @@ import pytest import requests -from google.api_core import rest_streaming +from google.api_core import rest_streaming_new from google.api import http_pb2 from google.api import httpbody_pb2 -from google.protobuf import duration_pb2 -from google.protobuf import timestamp_pb2 -from google.protobuf.json_format import MessageToJson + +from ..conftest import Composer, Song, EchoResponse, parse_responses __protobuf__ = proto.module(package=__name__) @@ -37,37 +36,6 @@ random.seed(SEED) -class Genre(proto.Enum): - GENRE_UNSPECIFIED = 0 - CLASSICAL = 1 - JAZZ = 2 - ROCK = 3 - - -class Composer(proto.Message): - given_name = proto.Field(proto.STRING, number=1) - family_name = proto.Field(proto.STRING, number=2) - relateds = proto.RepeatedField(proto.STRING, number=3) - indices = proto.MapField(proto.STRING, proto.STRING, number=4) - - -class Song(proto.Message): - composer = proto.Field(Composer, number=1) - title = proto.Field(proto.STRING, number=2) - lyrics = proto.Field(proto.STRING, number=3) - year = proto.Field(proto.INT32, number=4) - genre = proto.Field(Genre, number=5) - is_five_mins_longer = proto.Field(proto.BOOL, number=6) - score = proto.Field(proto.DOUBLE, number=7) - likes = proto.Field(proto.INT64, number=8) - duration = proto.Field(duration_pb2.Duration, number=9) - date_added = proto.Field(timestamp_pb2.Timestamp, number=10) - - -class EchoResponse(proto.Message): - content = proto.Field(proto.STRING, number=1) - - class ResponseMock(requests.Response): class _ResponseItr: def __init__(self, _response_bytes: bytes, random_split=False): @@ -97,30 +65,18 @@ def __init__( self._random_split = random_split self._response_message_cls = response_cls - def _parse_responses(self, responses: List[proto.Message]) -> bytes: - # json.dumps returns a string surrounded with quotes that need to be stripped - # in order to be an actual JSON. - json_responses = [ - ( - self._response_message_cls.to_json(r).strip('"') - if issubclass(self._response_message_cls, proto.Message) - else MessageToJson(r).strip('"') - ) - for r in responses - ] - logging.info(f"Sending JSON stream: {json_responses}") - ret_val = "[{}]".format(",".join(json_responses)) - return bytes(ret_val, "utf-8") - def close(self): raise NotImplementedError() def iter_content(self, *args, **kwargs): return self._ResponseItr( - self._parse_responses(self._responses), + self._parse_responses(), random_split=self._random_split, ) + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) + @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", @@ -140,7 +96,7 @@ def test_next_simple(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -177,7 +133,7 @@ def test_next_nested(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -210,7 +166,7 @@ def test_next_stress(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -271,7 +227,7 @@ def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_p resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -281,7 +237,7 @@ def test_next_not_array(response_type): ResponseMock, "iter_content", return_value=iter('{"hello": 0}') ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) mock_method.assert_called_once() @@ -291,7 +247,7 @@ def test_next_not_array(response_type): def test_cancel(response_type): with patch.object(ResponseMock, "close", return_value=None) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) itr.cancel() mock_method.assert_called_once() @@ -310,7 +266,7 @@ def test_check_buffer(response_type, return_value): return_value=return_value, ): resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) next(itr) @@ -322,7 +278,7 @@ def test_next_html(response_type): ResponseMock, "iter_content", return_value=iter("") ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_new.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) mock_method.assert_called_once() @@ -333,7 +289,7 @@ class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) - response_iterator = rest_streaming.ResponseIterator(resp, SomeClass) + response_iterator = rest_streaming_new.ResponseIterator(resp, SomeClass) with pytest.raises( ValueError, match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", diff --git a/tests/unit/test_rest_streaming_async.py b/tests/unit/test_rest_streaming_async.py index 0f2b3b32..e0ba799a 100644 --- a/tests/unit/test_rest_streaming_async.py +++ b/tests/unit/test_rest_streaming_async.py @@ -1,4 +1,4 @@ -# Copyright 2021 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,75 +16,52 @@ import logging import random import time -from typing import List -from unittest.mock import patch +from typing import List, AsyncIterator +import mock import proto import pytest -import requests -from google.api_core import rest_streaming +from google.api_core import rest_streaming_async_new from google.api import http_pb2 from google.api import httpbody_pb2 -from google.protobuf import duration_pb2 -from google.protobuf import timestamp_pb2 -from google.protobuf.json_format import MessageToJson +from google.auth.aio.transport import Response +from ..conftest import Composer, Song, EchoResponse, parse_responses +# TODO (ohmayr): check if we need to log. __protobuf__ = proto.module(package=__name__) SEED = int(time.time()) -logging.info(f"Starting rest streaming tests with random seed: {SEED}") +logging.info(f"Starting async rest streaming tests with random seed: {SEED}") random.seed(SEED) -class Genre(proto.Enum): - GENRE_UNSPECIFIED = 0 - CLASSICAL = 1 - JAZZ = 2 - ROCK = 3 +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") -class Composer(proto.Message): - given_name = proto.Field(proto.STRING, number=1) - family_name = proto.Field(proto.STRING, number=2) - relateds = proto.RepeatedField(proto.STRING, number=3) - indices = proto.MapField(proto.STRING, proto.STRING, number=4) - - -class Song(proto.Message): - composer = proto.Field(Composer, number=1) - title = proto.Field(proto.STRING, number=2) - lyrics = proto.Field(proto.STRING, number=3) - year = proto.Field(proto.INT32, number=4) - genre = proto.Field(Genre, number=5) - is_five_mins_longer = proto.Field(proto.BOOL, number=6) - score = proto.Field(proto.DOUBLE, number=7) - likes = proto.Field(proto.INT64, number=8) - duration = proto.Field(duration_pb2.Duration, number=9) - date_added = proto.Field(timestamp_pb2.Timestamp, number=10) - - -class EchoResponse(proto.Message): - content = proto.Field(proto.STRING, number=1) - - -class ResponseMock(requests.Response): - class _ResponseItr: +class ResponseMock(Response): + class _ResponseItr(AsyncIterator[bytes]): def __init__(self, _response_bytes: bytes, random_split=False): self._responses_bytes = _response_bytes self._i = 0 self._random_split = random_split - def __next__(self): + def __aiter__(self): + return self + + async def __anext__(self): if self._i == len(self._responses_bytes): - raise StopIteration + raise StopAsyncIteration if self._random_split: n = random.randint(1, len(self._responses_bytes[self._i :])) else: n = 1 x = self._responses_bytes[self._i : self._i + n] self._i += n - return x.decode("utf-8") + return x def __init__( self, @@ -92,41 +69,42 @@ def __init__( response_cls, random_split=False, ): - super().__init__() self._responses = responses self._random_split = random_split self._response_message_cls = response_cls - def _parse_responses(self, responses: List[proto.Message]) -> bytes: - # json.dumps returns a string surrounded with quotes that need to be stripped - # in order to be an actual JSON. - json_responses = [ - ( - self._response_message_cls.to_json(r).strip('"') - if issubclass(self._response_message_cls, proto.Message) - else MessageToJson(r).strip('"') - ) - for r in responses - ] - logging.info(f"Sending JSON stream: {json_responses}") - ret_val = "[{}]".format(",".join(json_responses)) - return bytes(ret_val, "utf-8") - - def close(self): + @property + async def close(self): raise NotImplementedError() - def iter_content(self, *args, **kwargs): - return self._ResponseItr( - self._parse_responses(self._responses), - random_split=self._random_split, + async def content(self, chunk_size=None): + itr = self._ResponseItr( + self._parse_responses(), random_split=self._random_split ) + async for chunk in itr: + yield chunk + async def read(self): + raise NotImplementedError() + @property + async def headers(self): + raise NotImplementedError() + + @property + async def status_code(self): + raise NotImplementedError() + + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) + + +@pytest.mark.asyncio @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [(False, True), (False, False)], ) -def test_next_simple(random_split, resp_message_is_proto_plus): +async def test_next_simple(random_split, resp_message_is_proto_plus): if resp_message_is_proto_plus: response_type = EchoResponse responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] @@ -140,10 +118,14 @@ def test_next_simple(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) - assert list(itr) == responses + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + i = 0 + async for response in itr: + assert response == responses[i] + i += 1 +@pytest.mark.asyncio @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [ @@ -153,7 +135,7 @@ def test_next_simple(random_split, resp_message_is_proto_plus): (False, False), ], ) -def test_next_nested(random_split, resp_message_is_proto_plus): +async def test_next_nested(random_split, resp_message_is_proto_plus): if resp_message_is_proto_plus: response_type = Song responses = [ @@ -177,10 +159,15 @@ def test_next_nested(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) - assert list(itr) == responses + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + i = 0 + async for response in itr: + assert response == responses[i] + i += 1 + assert i == len(responses) +@pytest.mark.asyncio @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [ @@ -190,7 +177,7 @@ def test_next_nested(random_split, resp_message_is_proto_plus): (False, False), ], ) -def test_next_stress(random_split, resp_message_is_proto_plus): +async def test_next_stress(random_split, resp_message_is_proto_plus): n = 50 if resp_message_is_proto_plus: response_type = Song @@ -210,10 +197,15 @@ def test_next_stress(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) - assert list(itr) == responses + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + i = 0 + async for response in itr: + assert response == responses[i] + i += 1 + assert i == n +@pytest.mark.asyncio @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [ @@ -223,7 +215,9 @@ def test_next_stress(random_split, resp_message_is_proto_plus): (False, False), ], ) -def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus): +async def test_next_escaped_characters_in_string( + random_split, resp_message_is_proto_plus +): if resp_message_is_proto_plus: response_type = Song composer_with_relateds = Composer() @@ -271,31 +265,43 @@ def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_p resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming.ResponseIterator(resp, response_type) - assert list(itr) == responses + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + + i = 0 + async for response in itr: + assert response == responses[i] + i += 1 + assert i == len(responses) +@pytest.mark.asyncio @pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) -def test_next_not_array(response_type): - with patch.object( - ResponseMock, "iter_content", return_value=iter('{"hello": 0}') +async def test_next_not_array(response_type): + + data = '{"hello": 0}' + with mock.patch.object( + ResponseMock, "content", return_value=mock_async_gen(data) ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) with pytest.raises(ValueError): - next(itr) + await itr.__anext__() mock_method.assert_called_once() +@pytest.mark.asyncio @pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) -def test_cancel(response_type): - with patch.object(ResponseMock, "close", return_value=None) as mock_method: +async def test_cancel(response_type): + with mock.patch.object( + ResponseMock, "close", new_callable=mock.AsyncMock + ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) - itr.cancel() + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + await itr.cancel() mock_method.assert_called_once() +@pytest.mark.asyncio @pytest.mark.parametrize( "response_type,return_value", [ @@ -303,37 +309,43 @@ def test_cancel(response_type): (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")), ], ) -def test_check_buffer(response_type, return_value): - with patch.object( +async def test_check_buffer(response_type, return_value): + with mock.patch.object( ResponseMock, "_parse_responses", return_value=return_value, ): resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) with pytest.raises(ValueError): - next(itr) - next(itr) + await itr.__anext__() + await itr.__anext__() +@pytest.mark.asyncio @pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) -def test_next_html(response_type): - with patch.object( - ResponseMock, "iter_content", return_value=iter("") +async def test_next_html(response_type): + + data = "" + with mock.patch.object( + ResponseMock, "content", return_value=mock_async_gen(data) ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming.ResponseIterator(resp, response_type) + + itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) with pytest.raises(ValueError): - next(itr) + await itr.__anext__() mock_method.assert_called_once() -def test_invalid_response_class(): +@pytest.mark.asyncio +async def test_invalid_response_class(): class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) - response_iterator = rest_streaming.ResponseIterator(resp, SomeClass) + response_iterator = rest_streaming_async_new.AsyncResponseIterator(resp, SomeClass) + with pytest.raises( ValueError, match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", From 75ae7d7e7c745ef52fa17386af954e4998d3fe42 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Fri, 23 Aug 2024 15:40:30 +0000 Subject: [PATCH 08/24] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- google/api_core/rest_streaming_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index 6e9c87f3..6c225c6b 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -54,7 +54,7 @@ async def cancel(self): async def __anext__(self): while not self._ready_objs: - try: + try: chunk = await self._response_itr.__anext__() chunk = chunk.decode("utf-8") self._process_chunk(chunk) From db7cfb3c5b1bda0f603e0a08c7cb812017db5a0f Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:49:16 +0000 Subject: [PATCH 09/24] fix naming issue --- tests/unit/test_rest_streaming_async.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/unit/test_rest_streaming_async.py b/tests/unit/test_rest_streaming_async.py index e0ba799a..6fdfa9ca 100644 --- a/tests/unit/test_rest_streaming_async.py +++ b/tests/unit/test_rest_streaming_async.py @@ -22,7 +22,7 @@ import proto import pytest -from google.api_core import rest_streaming_async_new +from google.api_core import rest_streaming_async from google.api import http_pb2 from google.api import httpbody_pb2 from google.auth.aio.transport import Response @@ -118,7 +118,7 @@ async def test_next_simple(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) i = 0 async for response in itr: assert response == responses[i] @@ -159,7 +159,7 @@ async def test_next_nested(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) i = 0 async for response in itr: assert response == responses[i] @@ -197,7 +197,7 @@ async def test_next_stress(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) i = 0 async for response in itr: assert response == responses[i] @@ -265,7 +265,7 @@ async def test_next_escaped_characters_in_string( resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) i = 0 async for response in itr: @@ -283,7 +283,7 @@ async def test_next_not_array(response_type): ResponseMock, "content", return_value=mock_async_gen(data) ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) with pytest.raises(ValueError): await itr.__anext__() mock_method.assert_called_once() @@ -296,7 +296,7 @@ async def test_cancel(response_type): ResponseMock, "close", new_callable=mock.AsyncMock ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) await itr.cancel() mock_method.assert_called_once() @@ -316,7 +316,7 @@ async def test_check_buffer(response_type, return_value): return_value=return_value, ): resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) with pytest.raises(ValueError): await itr.__anext__() await itr.__anext__() @@ -332,7 +332,7 @@ async def test_next_html(response_type): ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_async_new.AsyncResponseIterator(resp, response_type) + itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) with pytest.raises(ValueError): await itr.__anext__() mock_method.assert_called_once() @@ -344,7 +344,7 @@ class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) - response_iterator = rest_streaming_async_new.AsyncResponseIterator(resp, SomeClass) + response_iterator = rest_streaming_async.AsyncResponseIterator(resp, SomeClass) with pytest.raises( ValueError, From 87eeca3248470973e6d0ab04d04b605f74b2351f Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 15:59:51 +0000 Subject: [PATCH 10/24] fix import module name --- tests/unit/test_rest_streaming.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index b215d02b..50beb047 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -23,7 +23,7 @@ import pytest import requests -from google.api_core import rest_streaming_new +from google.api_core import rest_streaming from google.api import http_pb2 from google.api import httpbody_pb2 @@ -96,7 +96,7 @@ def test_next_simple(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -133,7 +133,7 @@ def test_next_nested(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -166,7 +166,7 @@ def test_next_stress(random_split, resp_message_is_proto_plus): resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -227,7 +227,7 @@ def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_p resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @@ -237,7 +237,7 @@ def test_next_not_array(response_type): ResponseMock, "iter_content", return_value=iter('{"hello": 0}') ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) mock_method.assert_called_once() @@ -247,7 +247,7 @@ def test_next_not_array(response_type): def test_cancel(response_type): with patch.object(ResponseMock, "close", return_value=None) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) itr.cancel() mock_method.assert_called_once() @@ -266,7 +266,7 @@ def test_check_buffer(response_type, return_value): return_value=return_value, ): resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) next(itr) @@ -278,7 +278,7 @@ def test_next_html(response_type): ResponseMock, "iter_content", return_value=iter("") ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) - itr = rest_streaming_new.ResponseIterator(resp, response_type) + itr = rest_streaming.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) mock_method.assert_called_once() @@ -289,7 +289,7 @@ class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) - response_iterator = rest_streaming_new.ResponseIterator(resp, SomeClass) + response_iterator = rest_streaming.ResponseIterator(resp, SomeClass) with pytest.raises( ValueError, match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", From d6abddd547149ce2eb41d94347e5f353e97b2863 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 16:00:10 +0000 Subject: [PATCH 11/24] pull auth feature branch --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a9e01f49..52cbea6f 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,8 @@ "googleapis-common-protos >= 1.56.2, < 2.0.dev0", "protobuf>=3.19.5,<6.0.0.dev0,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", "proto-plus >= 1.22.3, <2.0.0dev", - "google-auth >= 2.14.1, < 3.0.dev0", + # TODO (ohmayr): Temporary change to make test cases pass. Revert before merging. + "google-auth @ git+ssh://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad", "requests >= 2.18.0, < 3.0.0.dev0", ] extras = { From a6a648df163fb4297933c1a8e70e0235c30b7502 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Aug 2024 16:02:17 +0000 Subject: [PATCH 12/24] revert setup file --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 52cbea6f..a9e01f49 100644 --- a/setup.py +++ b/setup.py @@ -32,8 +32,7 @@ "googleapis-common-protos >= 1.56.2, < 2.0.dev0", "protobuf>=3.19.5,<6.0.0.dev0,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", "proto-plus >= 1.22.3, <2.0.0dev", - # TODO (ohmayr): Temporary change to make test cases pass. Revert before merging. - "google-auth @ git+ssh://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad", + "google-auth >= 2.14.1, < 3.0.dev0", "requests >= 2.18.0, < 3.0.0.dev0", ] extras = { From 8d30b2d9918ec5cfc260d230d1dcc76d969a2fc2 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Sat, 24 Aug 2024 20:52:12 +0000 Subject: [PATCH 13/24] address PR comments --- google/api_core/_rest_streaming_base.py | 5 +- google/api_core/rest_streaming.py | 4 +- google/api_core/rest_streaming_async.py | 4 +- tests/{conftest.py => helpers.py} | 7 +-- tests/unit/test_rest_streaming.py | 9 ++-- tests/unit/test_rest_streaming_async.py | 66 ++++++++++++------------- 6 files changed, 47 insertions(+), 48 deletions(-) rename tests/{conftest.py => helpers.py} (92%) diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py index 03d7011e..406184d4 100644 --- a/google/api_core/_rest_streaming_base.py +++ b/google/api_core/_rest_streaming_base.py @@ -27,7 +27,6 @@ class BaseResponseIterator: """Base Iterator over REST API responses. This class should not be used directly. Args: - response (requests.Response): An API response object. response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response class expected to be returned from an API. @@ -40,8 +39,6 @@ def __init__( response_message_cls: Union[proto.Message, google.protobuf.message.Message], ): self._response_message_cls = response_message_cls - # Inner iterator over HTTP response's content. - # self._response_itr = self._response.iter_content(decode_unicode=True) # Contains a list of JSON responses ready to be sent to user. self._ready_objs: Deque[str] = deque() # Current JSON response being built. @@ -100,7 +97,7 @@ def _process_chunk(self, chunk: str): def _grab(self): # Add extra quotes to make json.loads happy. if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json(self._ready_objs.popleft()) + return self._response_message_cls.from_json(self._ready_objs.popleft(), ignore_unknown_fields=True) elif issubclass(self._response_message_cls, google.protobuf.message.Message): return Parse(self._ready_objs.popleft(), self._response_message_cls()) else: diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 98fdd85d..822f34f3 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -31,7 +31,9 @@ class ResponseIterator(BaseResponseIterator): class expected to be returned from an API. Raises: - ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + ValueError: + - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. + - If `response` is not an instance of `requests.Response`. """ def __init__( diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index 6c225c6b..11a9d792 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Helpers for server-side asynchronous streaming in REST.""" +"""Helpers for asynchronous server-side streaming in REST.""" from typing import Union @@ -33,7 +33,7 @@ class expected to be returned from an API. Raises: ValueError: - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. - - If `response` is not an instance of `google.auth.aio.transport.aiohttp.Response`. + - If `response` is not an instance of a subclass of `google.auth.aio.transport.Response`. """ def __init__( diff --git a/tests/conftest.py b/tests/helpers.py similarity index 92% rename from tests/conftest.py rename to tests/helpers.py index e74cce31..28ac20e5 100644 --- a/tests/conftest.py +++ b/tests/helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Helpers for tests""" import logging from typing import List @@ -59,11 +60,11 @@ def parse_responses(response_message_cls, responses: List[proto.Message]) -> byt # in order to be an actual JSON. json_responses = [ ( - response_message_cls.to_json(r).strip('"') + response_message_cls.to_json(response).strip('"') if issubclass(response_message_cls, proto.Message) - else MessageToJson(r).strip('"') + else MessageToJson(response).strip('"') ) - for r in responses + for response in responses ] logging.info(f"Sending JSON stream: {json_responses}") ret_val = "[{}]".format(",".join(json_responses)) diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 50beb047..006adef7 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -27,12 +27,12 @@ from google.api import http_pb2 from google.api import httpbody_pb2 -from ..conftest import Composer, Song, EchoResponse, parse_responses +from ..helpers import Composer, Song, EchoResponse, parse_responses __protobuf__ = proto.module(package=__name__) SEED = int(time.time()) -logging.info(f"Starting rest streaming tests with random seed: {SEED}") +logging.info(f"Starting sync rest streaming tests with random seed: {SEED}") random.seed(SEED) @@ -65,6 +65,9 @@ def __init__( self._random_split = random_split self._response_message_cls = response_cls + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) + def close(self): raise NotImplementedError() @@ -74,8 +77,6 @@ def iter_content(self, *args, **kwargs): random_split=self._random_split, ) - def _parse_responses(self): - return parse_responses(self._response_message_cls, self._responses) @pytest.mark.parametrize( diff --git a/tests/unit/test_rest_streaming_async.py b/tests/unit/test_rest_streaming_async.py index 6fdfa9ca..a8e9cec2 100644 --- a/tests/unit/test_rest_streaming_async.py +++ b/tests/unit/test_rest_streaming_async.py @@ -27,9 +27,9 @@ from google.api import httpbody_pb2 from google.auth.aio.transport import Response -from ..conftest import Composer, Song, EchoResponse, parse_responses +from ..helpers import Composer, Song, EchoResponse, parse_responses + -# TODO (ohmayr): check if we need to log. __protobuf__ = proto.module(package=__name__) SEED = int(time.time()) logging.info(f"Starting async rest streaming tests with random seed: {SEED}") @@ -46,21 +46,21 @@ class ResponseMock(Response): class _ResponseItr(AsyncIterator[bytes]): def __init__(self, _response_bytes: bytes, random_split=False): self._responses_bytes = _response_bytes - self._i = 0 + self._idx = 0 self._random_split = random_split def __aiter__(self): return self async def __anext__(self): - if self._i == len(self._responses_bytes): + if self._idx >= len(self._responses_bytes): raise StopAsyncIteration if self._random_split: - n = random.randint(1, len(self._responses_bytes[self._i :])) + n = random.randint(1, len(self._responses_bytes[self._idx :])) else: n = 1 - x = self._responses_bytes[self._i : self._i + n] - self._i += n + x = self._responses_bytes[self._idx : self._idx + n] + self._idx += n return x def __init__( @@ -73,7 +73,17 @@ def __init__( self._random_split = random_split self._response_message_cls = response_cls + def _parse_responses(self): + return parse_responses(self._response_message_cls, self._responses) + + @property + async def headers(self): + raise NotImplementedError() + @property + async def status_code(self): + raise NotImplementedError() + async def close(self): raise NotImplementedError() @@ -87,17 +97,6 @@ async def content(self, chunk_size=None): async def read(self): raise NotImplementedError() - @property - async def headers(self): - raise NotImplementedError() - - @property - async def status_code(self): - raise NotImplementedError() - - def _parse_responses(self): - return parse_responses(self._response_message_cls, self._responses) - @pytest.mark.asyncio @pytest.mark.parametrize( @@ -119,10 +118,10 @@ async def test_next_simple(random_split, resp_message_is_proto_plus): responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) - i = 0 + idx = 0 async for response in itr: - assert response == responses[i] - i += 1 + assert response == responses[idx] + idx += 1 @pytest.mark.asyncio @@ -160,11 +159,11 @@ async def test_next_nested(random_split, resp_message_is_proto_plus): responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) - i = 0 + idx = 0 async for response in itr: - assert response == responses[i] - i += 1 - assert i == len(responses) + assert response == responses[idx] + idx += 1 + assert idx == len(responses) @pytest.mark.asyncio @@ -198,11 +197,11 @@ async def test_next_stress(random_split, resp_message_is_proto_plus): responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) - i = 0 + idx = 0 async for response in itr: - assert response == responses[i] - i += 1 - assert i == n + assert response == responses[idx] + idx += 1 + assert idx == n @pytest.mark.asyncio @@ -266,12 +265,11 @@ async def test_next_escaped_characters_in_string( responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming_async.AsyncResponseIterator(resp, response_type) - - i = 0 + idx = 0 async for response in itr: - assert response == responses[i] - i += 1 - assert i == len(responses) + assert response == responses[idx] + idx += 1 + assert idx == len(responses) @pytest.mark.asyncio From 0b51b09f2313e8f54998430011925ef6b72a2630 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Sat, 24 Aug 2024 20:54:02 +0000 Subject: [PATCH 14/24] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- google/api_core/_rest_streaming_base.py | 4 +++- tests/unit/test_rest_streaming.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py index 406184d4..1a870985 100644 --- a/google/api_core/_rest_streaming_base.py +++ b/google/api_core/_rest_streaming_base.py @@ -97,7 +97,9 @@ def _process_chunk(self, chunk: str): def _grab(self): # Add extra quotes to make json.loads happy. if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json(self._ready_objs.popleft(), ignore_unknown_fields=True) + return self._response_message_cls.from_json( + self._ready_objs.popleft(), ignore_unknown_fields=True + ) elif issubclass(self._response_message_cls, google.protobuf.message.Message): return Parse(self._ready_objs.popleft(), self._response_message_cls()) else: diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 006adef7..9ab90d7b 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -78,7 +78,6 @@ def iter_content(self, *args, **kwargs): ) - @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [(False, True), (False, False)], From 7d8f1e1dcaf307001b9ba6a5ea6aa0d3894701c5 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Sat, 24 Aug 2024 21:01:05 +0000 Subject: [PATCH 15/24] run black --- google/api_core/_rest_streaming_base.py | 4 +++- tests/unit/test_rest_streaming.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py index 406184d4..1a870985 100644 --- a/google/api_core/_rest_streaming_base.py +++ b/google/api_core/_rest_streaming_base.py @@ -97,7 +97,9 @@ def _process_chunk(self, chunk: str): def _grab(self): # Add extra quotes to make json.loads happy. if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json(self._ready_objs.popleft(), ignore_unknown_fields=True) + return self._response_message_cls.from_json( + self._ready_objs.popleft(), ignore_unknown_fields=True + ) elif issubclass(self._response_message_cls, google.protobuf.message.Message): return Parse(self._ready_objs.popleft(), self._response_message_cls()) else: diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 006adef7..9ab90d7b 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -78,7 +78,6 @@ def iter_content(self, *args, **kwargs): ) - @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [(False, True), (False, False)], From d7930695e56418a635cf1906cc318b275c2038a5 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Aug 2024 05:35:03 +0000 Subject: [PATCH 16/24] address PR comments --- google/api_core/_rest_streaming_base.py | 32 +++++++++++-------- google/api_core/rest_streaming_async.py | 10 +++++- noxfile.py | 13 +++++++- .../test_rest_streaming_async.py | 27 ++++++++++++---- tests/helpers.py | 4 +-- tests/unit/test_rest_streaming.py | 3 +- 6 files changed, 64 insertions(+), 25 deletions(-) rename tests/{unit => asyncio}/test_rest_streaming_async.py (95%) diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py index 1a870985..89b15f10 100644 --- a/google/api_core/_rest_streaming_base.py +++ b/google/api_core/_rest_streaming_base.py @@ -17,6 +17,7 @@ from collections import deque import string from typing import Deque, Union +import types import proto import google.protobuf.message @@ -51,6 +52,24 @@ def __init__( # Whether an escape symbol "\" was encountered. self._escape_next = False + if issubclass(self._response_message_cls, proto.Message): + + def grab(this): + return this._response_message_cls.from_json( + self._ready_objs.popleft(), ignore_unknown_fields=True + ) + + elif issubclass(self._response_message_cls, google.protobuf.message.Message): + + def grab(this): + return Parse(this._ready_objs.popleft(), self._response_message_cls()) + + else: + raise ValueError( + "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." + ) + self._grab = types.MethodType(grab, self) + def _process_chunk(self, chunk: str): if self._level == 0: if chunk[0] != "[": @@ -93,16 +112,3 @@ def _process_chunk(self, chunk: str): else: self._obj += char self._escape_next = not self._escape_next if char == "\\" else False - - def _grab(self): - # Add extra quotes to make json.loads happy. - if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json( - self._ready_objs.popleft(), ignore_unknown_fields=True - ) - elif issubclass(self._response_message_cls, google.protobuf.message.Message): - return Parse(self._ready_objs.popleft(), self._response_message_cls()) - else: - raise ValueError( - "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." - ) diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index 11a9d792..9d1c0f82 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -17,7 +17,15 @@ from typing import Union import proto -import google.auth.aio.transport + +try: + import google.auth.aio.transport +except ImportError: + # TODO (ohmayr): Update this version once auth work is released. + raise ValueError( + "google-auth>=2.3x.x is required to use asynchronous rest streaming." + ) + import google.protobuf.message from google.api_core._rest_streaming_base import BaseResponseIterator diff --git a/noxfile.py b/noxfile.py index a15795ea..3d4b042a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -109,7 +109,7 @@ def install_prerelease_dependencies(session, constraints_path): session.install(*other_deps) -def default(session, install_grpc=True, prerelease=False): +def default(session, install_grpc=True, prerelease=False, install_auth_aio=False): """Default unit test session. This is intended to be run **without** an interpreter set, so @@ -144,6 +144,11 @@ def default(session, install_grpc=True, prerelease=False): f"{constraints_dir}/constraints-{session.python}.txt", ) + if install_auth_aio: + session.install( + "google-auth @ git+ssh://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad" + ) + # Print out package versions of dependencies session.run( "python", "-c", "import google.protobuf; print(google.protobuf.__version__)" @@ -229,6 +234,12 @@ def unit_wo_grpc(session): default(session, install_grpc=False) +@nox.session(python=PYTHON_VERSIONS) +def unit_with_auth_aio(session): + """Run the unit test suite with google.auth.aio installed""" + default(session, install_auth_aio=True) + + @nox.session(python=DEFAULT_PYTHON_VERSION) def lint_setup_py(session): """Verify that setup.py is valid (including RST check).""" diff --git a/tests/unit/test_rest_streaming_async.py b/tests/asyncio/test_rest_streaming_async.py similarity index 95% rename from tests/unit/test_rest_streaming_async.py rename to tests/asyncio/test_rest_streaming_async.py index a8e9cec2..fe465a30 100644 --- a/tests/unit/test_rest_streaming_async.py +++ b/tests/asyncio/test_rest_streaming_async.py @@ -12,20 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: set random.seed explicitly in each test function. +# See related issue: https://github.com/googleapis/python-api-core/issues/689. + +import pytest # noqa: I202 +import mock + import datetime import logging import random import time from typing import List, AsyncIterator -import mock import proto -import pytest + +try: + from google.auth.aio.transport import Response + + AUTH_AIO_INSTALLED = True +except ImportError: + AUTH_AIO_INSTALLED = False + +if not AUTH_AIO_INSTALLED: # pragma: NO COVER + pytest.skip( + "google-auth>=2.3x.x is required to use asynchronous rest streaming.", + allow_module_level=True, + ) from google.api_core import rest_streaming_async from google.api import http_pb2 from google.api import httpbody_pb2 -from google.auth.aio.transport import Response + from ..helpers import Composer, Song, EchoResponse, parse_responses @@ -342,10 +359,8 @@ class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) - response_iterator = rest_streaming_async.AsyncResponseIterator(resp, SomeClass) - with pytest.raises( ValueError, match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", ): - response_iterator._grab() + rest_streaming_async.AsyncResponseIterator(resp, SomeClass) diff --git a/tests/helpers.py b/tests/helpers.py index 28ac20e5..3429d511 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -55,7 +55,7 @@ class EchoResponse(proto.Message): content = proto.Field(proto.STRING, number=1) -def parse_responses(response_message_cls, responses: List[proto.Message]) -> bytes: +def parse_responses(response_message_cls, all_responses: List[proto.Message]) -> bytes: # json.dumps returns a string surrounded with quotes that need to be stripped # in order to be an actual JSON. json_responses = [ @@ -64,7 +64,7 @@ def parse_responses(response_message_cls, responses: List[proto.Message]) -> byt if issubclass(response_message_cls, proto.Message) else MessageToJson(response).strip('"') ) - for response in responses + for response in all_responses ] logging.info(f"Sending JSON stream: {json_responses}") ret_val = "[{}]".format(",".join(json_responses)) diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 9ab90d7b..0f998dfe 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -289,9 +289,8 @@ class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) - response_iterator = rest_streaming.ResponseIterator(resp, SomeClass) with pytest.raises( ValueError, match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", ): - response_iterator._grab() + rest_streaming.ResponseIterator(resp, SomeClass) From 70d1cdbe3d13ef234b25d9b2178072a237d17e69 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Aug 2024 06:33:28 +0000 Subject: [PATCH 17/24] update nox coverage --- noxfile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/noxfile.py b/noxfile.py index 3d4b042a..d5cd2441 100644 --- a/noxfile.py +++ b/noxfile.py @@ -38,6 +38,7 @@ "unit", "unit_grpc_gcp", "unit_wo_grpc", + "unit_with_auth_aio", "cover", "pytype", "mypy", From ee3a5acb38d46d2c22c09b3fec01ef9efb148444 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Aug 2024 06:39:12 +0000 Subject: [PATCH 18/24] address PR comments --- .github/workflows/unittest.yml | 2 +- google/api_core/_rest_streaming_base.py | 38 ++++++++++++++----------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 34d29b7c..acd7000b 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps"] + option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps", "unit_with_auth_aio"] python: - "3.7" - "3.8" diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py index 89b15f10..3bc87a96 100644 --- a/google/api_core/_rest_streaming_base.py +++ b/google/api_core/_rest_streaming_base.py @@ -52,23 +52,7 @@ def __init__( # Whether an escape symbol "\" was encountered. self._escape_next = False - if issubclass(self._response_message_cls, proto.Message): - - def grab(this): - return this._response_message_cls.from_json( - self._ready_objs.popleft(), ignore_unknown_fields=True - ) - - elif issubclass(self._response_message_cls, google.protobuf.message.Message): - - def grab(this): - return Parse(this._ready_objs.popleft(), self._response_message_cls()) - - else: - raise ValueError( - "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." - ) - self._grab = types.MethodType(grab, self) + self._grab = types.MethodType(self._create_grab(), self) def _process_chunk(self, chunk: str): if self._level == 0: @@ -112,3 +96,23 @@ def _process_chunk(self, chunk: str): else: self._obj += char self._escape_next = not self._escape_next if char == "\\" else False + + def _create_grab(self): + if issubclass(self._response_message_cls, proto.Message): + + def grab(this): + return this._response_message_cls.from_json( + this._ready_objs.popleft(), ignore_unknown_fields=True + ) + + return grab + elif issubclass(self._response_message_cls, google.protobuf.message.Message): + + def grab(this): + return Parse(this._ready_objs.popleft(), this._response_message_cls()) + + return grab + else: + raise ValueError( + "Response message class must be a subclass of proto.Message or google.protobuf.message.Message." + ) From 5a6e48860f6923874ff4cc2c031e294e5fecde47 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Aug 2024 06:40:40 +0000 Subject: [PATCH 19/24] fix nox session name in workflow --- .github/workflows/unittest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index acd7000b..504fdc09 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps", "unit_with_auth_aio"] + option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps", "_with_auth_aio"] python: - "3.7" - "3.8" From f2180daae5ed06930bee520c779a8826188d74e3 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Aug 2024 06:45:39 +0000 Subject: [PATCH 20/24] use https for remote repo --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index d5cd2441..3fc4a722 100644 --- a/noxfile.py +++ b/noxfile.py @@ -147,7 +147,7 @@ def default(session, install_grpc=True, prerelease=False, install_auth_aio=False if install_auth_aio: session.install( - "google-auth @ git+ssh://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad" + "google-auth @ git+https://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad" ) # Print out package versions of dependencies From bc811e5c88d2f7fc9e086b61f1896b249d3ca869 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Aug 2024 07:06:00 +0000 Subject: [PATCH 21/24] add context manager methods --- google/api_core/rest_streaming_async.py | 5 ++++- tests/asyncio/test_rest_streaming_async.py | 14 +++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index 9d1c0f82..b24630b3 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -20,7 +20,7 @@ try: import google.auth.aio.transport -except ImportError: +except ImportError: # pragma: NO COVER # TODO (ohmayr): Update this version once auth work is released. raise ValueError( "google-auth>=2.3x.x is required to use asynchronous rest streaming." @@ -56,6 +56,9 @@ def __init__( response_message_cls=response_message_cls ) + async def __aenter__(self): + return self + async def cancel(self): """Cancel existing streaming operation.""" await self._response.close() diff --git a/tests/asyncio/test_rest_streaming_async.py b/tests/asyncio/test_rest_streaming_async.py index fe465a30..1f6f97de 100644 --- a/tests/asyncio/test_rest_streaming_async.py +++ b/tests/asyncio/test_rest_streaming_async.py @@ -54,7 +54,7 @@ async def mock_async_gen(data, chunk_size=1): - for i in range(0, len(data)): + for i in range(0, len(data)): # pragma: NO COVER chunk = data[i : i + chunk_size] yield chunk.encode("utf-8") @@ -316,6 +316,18 @@ async def test_cancel(response_type): mock_method.assert_called_once() +@pytest.mark.asyncio +@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) +async def test_iterator_as_context_manager(response_type): + with mock.patch.object( + ResponseMock, "close", new_callable=mock.AsyncMock + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=response_type) + async with rest_streaming_async.AsyncResponseIterator(resp, response_type): + pass + mock_method.assert_called_once() + + @pytest.mark.asyncio @pytest.mark.parametrize( "response_type,return_value", From fc74ae02456e26cc9c7f8137a3fee70cbef4550c Mon Sep 17 00:00:00 2001 From: ohmayr Date: Wed, 11 Sep 2024 05:46:57 +0000 Subject: [PATCH 22/24] address PR comments --- google/api_core/rest_streaming.py | 1 - google/api_core/rest_streaming_async.py | 1 - 2 files changed, 2 deletions(-) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 822f34f3..84aa270c 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -33,7 +33,6 @@ class expected to be returned from an API. Raises: ValueError: - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. - - If `response` is not an instance of `requests.Response`. """ def __init__( diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index b24630b3..acd5af6e 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -41,7 +41,6 @@ class expected to be returned from an API. Raises: ValueError: - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`. - - If `response` is not an instance of a subclass of `google.auth.aio.transport.Response`. """ def __init__( From cee06abe8b980997d624db69d5421fd8c9da759d Mon Sep 17 00:00:00 2001 From: ohmayr Date: Wed, 18 Sep 2024 15:07:34 +0000 Subject: [PATCH 23/24] update auth error versions --- google/api_core/rest_streaming_async.py | 3 +-- tests/asyncio/test_rest_streaming_async.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index acd5af6e..993a6c14 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -21,9 +21,8 @@ try: import google.auth.aio.transport except ImportError: # pragma: NO COVER - # TODO (ohmayr): Update this version once auth work is released. raise ValueError( - "google-auth>=2.3x.x is required to use asynchronous rest streaming." + "google-auth>=2.35.0 is required to use asynchronous rest streaming." ) import google.protobuf.message diff --git a/tests/asyncio/test_rest_streaming_async.py b/tests/asyncio/test_rest_streaming_async.py index 1f6f97de..35820de6 100644 --- a/tests/asyncio/test_rest_streaming_async.py +++ b/tests/asyncio/test_rest_streaming_async.py @@ -35,7 +35,7 @@ if not AUTH_AIO_INSTALLED: # pragma: NO COVER pytest.skip( - "google-auth>=2.3x.x is required to use asynchronous rest streaming.", + "google-auth>=2.35.0 is required to use asynchronous rest streaming.", allow_module_level=True, ) From 750d4cb8c87f02bf91761d0b0d9541d672972a8b Mon Sep 17 00:00:00 2001 From: ohmayr Date: Wed, 18 Sep 2024 15:11:54 +0000 Subject: [PATCH 24/24] update import error --- google/api_core/rest_streaming_async.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py index 993a6c14..d1f996f6 100644 --- a/google/api_core/rest_streaming_async.py +++ b/google/api_core/rest_streaming_async.py @@ -20,10 +20,10 @@ try: import google.auth.aio.transport -except ImportError: # pragma: NO COVER - raise ValueError( +except ImportError as e: # pragma: NO COVER + raise ImportError( "google-auth>=2.35.0 is required to use asynchronous rest streaming." - ) + ) from e import google.protobuf.message from google.api_core._rest_streaming_base import BaseResponseIterator