Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: iterator for processing JSON responses in REST streaming. #317

Merged
merged 11 commits into from
Jan 10, 2022
114 changes: 114 additions & 0 deletions google/api_core/rest_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for server-side streaming in REST."""

from collections import deque
import string
from typing import Deque

import requests


class ResponseIterator:
    """Iterator over REST API responses.

    The HTTP body is consumed incrementally and is expected to be a single
    JSON array of JSON objects. Each completed top-level object is converted
    with ``response_message_cls.from_json`` and yielded to the caller.

    Args:
        response (requests.Response): An API response object.
        response_message_cls (Callable[proto.Message]): A proto
          class expected to be returned from an API.
    """

    def __init__(self, response: "requests.Response", response_message_cls):
        self._response = response
        self._response_message_cls = response_message_cls
        # Inner iterator over HTTP response's content.
        self._response_itr = self._response.iter_content(decode_unicode=True)
        # Contains a list of JSON responses ready to be sent to user.
        self._ready_objs: Deque[str] = deque()
        # Current JSON response being built.
        self._obj = ""
        # Keeps track of the nesting level within a JSON object:
        # 0 = outside the top-level array, 1 = inside the array but between
        # objects, >= 2 = inside an object being accumulated.
        self._level = 0
        # Keeps track whether HTTP response is currently sending values
        # inside of a string value.
        self._in_string = False
        # Whether an escape symbol "\" was encountered.
        self._escape_next = False

    def cancel(self):
        """Cancel existing streaming operation by closing the response."""
        self._response.close()

    def _process_chunk(self, chunk: str):
        """Feed one piece of the response body into the incremental parser.

        Completed top-level objects are appended to ``self._ready_objs`` as
        JSON text; a partially received object stays in ``self._obj``.

        Args:
            chunk (str): The next piece of decoded response text. It may
                contain any number of (partial) JSON objects.

        Raises:
            ValueError: If, while outside the top-level array, the first
                non-whitespace character of ``chunk`` is not ``[``.
        """
        if self._level == 0:
            # JSON allows insignificant whitespace around the top-level
            # array, so empty and whitespace-only chunks (for instance a
            # trailing newline after the closing "]") are tolerated here
            # instead of failing with IndexError/ValueError.
            head = chunk.lstrip(string.whitespace)
            if head and head[0] != "[":
                raise ValueError(
                    "Can only parse array of JSON objects, instead got %s" % chunk
                )
        for char in chunk:
            if char == "{":
                if self._level == 1:
                    # Level 1 corresponds to the outermost JSON object
                    # (i.e. the one we care about).
                    self._obj = ""
                if not self._in_string:
                    self._level += 1
                self._obj += char
            elif char == "}":
                self._obj += char
                if not self._in_string:
                    self._level -= 1
                if not self._in_string and self._level == 1:
                    # Back at array level: the object is complete.
                    self._ready_objs.append(self._obj)
            elif char == '"':
                # Helps to deal with an escaped quotes inside of a string.
                if not self._escape_next:
                    self._in_string = not self._in_string
                self._obj += char
            elif char in string.whitespace:
                # Whitespace is only significant inside string values.
                if self._in_string:
                    self._obj += char
            elif char == "[":
                if self._level == 0:
                    # Opening bracket of the top-level array.
                    self._level += 1
                else:
                    self._obj += char
            elif char == "]":
                if self._level == 1:
                    # Closing bracket of the top-level array.
                    self._level -= 1
                else:
                    self._obj += char
            else:
                self._obj += char
            # A backslash escapes the next character unless it was itself
            # escaped; any other character clears the flag.
            self._escape_next = not self._escape_next if char == "\\" else False

    def __next__(self):
        """Return the next parsed message from the stream.

        Raises:
            StopIteration: When the response content is exhausted and no
                object is pending.
            ValueError: When the content ends while the top-level array or
                an object inside it is still open.
        """
        while not self._ready_objs:
            try:
                chunk = next(self._response_itr)
            except StopIteration:
                if self._level > 0:
                    raise ValueError("Unfinished stream: %s" % self._obj)
                raise
            self._process_chunk(chunk)
        return self._grab()

    def _grab(self):
        """Pop the oldest completed JSON object and convert it to a message."""
        return self._response_message_cls.from_json(self._ready_objs.popleft())

    def __iter__(self):
        return self
211 changes: 211 additions & 0 deletions tests/unit/test_rest_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import random
import time
from typing import List
from unittest.mock import patch

import proto
import pytest
import requests

from google.api_core import rest_streaming
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2


# Seed the PRNG from the wall clock and log the chosen seed so that the
# randomly sized chunk splits used by these tests can be reproduced.
SEED = int(time.time())
random.seed(SEED)
logging.info(f"Starting rest streaming tests with random seed: {SEED}")


class Genre(proto.Enum):
    """Enum fixture used as the type of the ``genre`` field of ``Song``."""

    GENRE_UNSPECIFIED = 0
    CLASSICAL = 1
    JAZZ = 2
    ROCK = 3


class Composer(proto.Message):
    """Message fixture with scalar, repeated and map string fields."""

    given_name = proto.Field(proto.STRING, number=1)
    family_name = proto.Field(proto.STRING, number=2)
    relateds = proto.RepeatedField(proto.STRING, number=3)
    indices = proto.MapField(proto.STRING, proto.STRING, number=4)

atulep marked this conversation as resolved.
Show resolved Hide resolved

class Song(proto.Message):
    """Message fixture mixing a nested message, an enum, scalar fields and
    the ``Duration``/``Timestamp`` protobuf types."""

    composer = proto.Field(Composer, number=1)
    title = proto.Field(proto.STRING, number=2)
    lyrics = proto.Field(proto.STRING, number=3)
    year = proto.Field(proto.INT32, number=4)
    genre = proto.Field(Genre, number=5)
    is_five_mins_longer = proto.Field(proto.BOOL, number=6)
    score = proto.Field(proto.DOUBLE, number=7)
    likes = proto.Field(proto.INT64, number=8)
    duration = proto.Field(duration_pb2.Duration, number=9)
    date_added = proto.Field(timestamp_pb2.Timestamp, number=10)


class EchoResponse(proto.Message):
    """Minimal message fixture carrying a single string field."""

    content = proto.Field(proto.STRING, number=1)


class ResponseMock(requests.Response):
    """Fake ``requests.Response`` whose content is a JSON array of messages.

    Args:
        responses: Messages serialized into the streamed JSON array.
        response_cls: Message class whose ``to_json`` serializes each item.
        random_split: When true, the content is delivered in randomly sized
            chunks; otherwise it is delivered one character at a time.
    """

    class _ResponseItr:
        """Iterator that hands out the payload as consecutive text chunks."""

        def __init__(self, _response_bytes: bytes, random_split=False):
            # Decode once up front so that a chunk boundary can never fall
            # in the middle of a multi-byte UTF-8 sequence.
            self._text = _response_bytes.decode("utf-8")
            self._i = 0
            self._random_split = random_split

        def __iter__(self):
            return self

        def __next__(self):
            remaining = len(self._text) - self._i
            if remaining <= 0:
                raise StopIteration
            # Either a random chunk size (at most what is left) or a single
            # character per step.
            n = random.randint(1, remaining) if self._random_split else 1
            chunk = self._text[self._i : self._i + n]
            self._i += n
            return chunk

    def __init__(
        self, responses: List[proto.Message], response_cls, random_split=False,
    ):
        super().__init__()
        self._responses = responses
        self._random_split = random_split
        self._response_message_cls = response_cls

    def _parse_responses(self, responses: List[proto.Message]) -> bytes:
        """Serialize ``responses`` into one UTF-8 encoded JSON array."""
        # Any double quotes surrounding the serialized text are stripped so
        # that each element is embedded as a JSON value, not a quoted string.
        # NOTE(review): if to_json already returns a bare object literal this
        # strip is a no-op - confirm.
        json_responses = [
            self._response_message_cls.to_json(r).strip('"') for r in responses
        ]
        logging.info(f"Sending JSON stream: {json_responses}")
        ret_val = "[{}]".format(",".join(json_responses))
        return bytes(ret_val, "utf-8")

    def close(self):
        # Not supported by the fake; test_cancel patches this method.
        raise NotImplementedError()

    def iter_content(self, *args, **kwargs):
        """Return a fresh chunk iterator over the serialized responses."""
        return self._ResponseItr(
            self._parse_responses(self._responses), random_split=self._random_split,
        )


atulep marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("random_split", [False])
def test_next_simple(random_split):
    """Flat messages streamed one character at a time round-trip unchanged."""
    responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=EchoResponse
    )
    itr = rest_streaming.ResponseIterator(resp, EchoResponse)
    assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_nested(random_split):
    """Messages containing nested objects round-trip unchanged."""
    responses = [
        Song(title="some song", composer=Composer(given_name="some name")),
        Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
    ]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=Song
    )
    itr = rest_streaming.ResponseIterator(resp, Song)
    assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_stress(random_split):
    """A stream of 50 nested messages is yielded completely and in order."""
    n = 50
    responses = [
        Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
        for i in range(n)
    ]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=Song
    )
    itr = rest_streaming.ResponseIterator(resp, Song)
    assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_escaped_characters_in_string(random_split):
    """Quotes, braces, brackets, backslashes and control characters inside
    string values must not confuse the object-boundary detection."""
    composer_with_relateds = Composer()
    relateds = ["Artist A", "Artist B"]
    composer_with_relateds.relateds = relateds

    responses = [
        Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
        Song(
            title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
        ),
        Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
    ]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=Song
    )
    itr = rest_streaming.ResponseIterator(resp, Song)
    assert list(itr) == responses


def test_next_not_array():
    """Content that is a bare JSON object (not an array) is rejected."""
    with patch.object(
        ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
    ) as mock_method:

        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        with pytest.raises(ValueError):
            next(itr)
        # Checked after the raises block so the assertion is actually reached.
        mock_method.assert_called_once()


def test_cancel():
    """``cancel`` closes the underlying response exactly once."""
    with patch.object(ResponseMock, "close", return_value=None) as mock_method:
        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        itr.cancel()
        mock_method.assert_called_once()


def test_check_buffer():
    """A stream that ends inside an unfinished object raises ValueError."""
    with patch.object(
        ResponseMock,
        "_parse_responses",
        return_value=bytes('[{"content": "hello"}, {', "utf-8"),
    ):
        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        # The complete first object must still be delivered; keeping this
        # call outside the raises block ensures that an error here cannot be
        # mistaken for the expected truncation error.
        assert next(itr) == EchoResponse(content="hello")
        with pytest.raises(ValueError):
            next(itr)


def test_next_html():
    """Non-JSON content such as an HTML page is rejected."""
    with patch.object(
        ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
    ) as mock_method:

        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        with pytest.raises(ValueError):
            next(itr)
        # Checked after the raises block so the assertion is actually reached.
        mock_method.assert_called_once()