Skip to content

Commit

Permalink
[Python][RRIO] Call PTransform with setup teardown (#29585)
Browse files Browse the repository at this point in the history
* rrio call ptransform with setup teardown

* add main

* add SECS suffix

* update names
  • Loading branch information
riteshghorse authored Dec 8, 2023
1 parent 209b095 commit a7976f5
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 17 deletions.
181 changes: 167 additions & 14 deletions sdks/python/apache_beam/io/requestresponseio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,24 @@

"""``PTransform`` for reading from and writing to Web APIs."""
import abc
import concurrent.futures
import contextlib
import logging
import sys
from typing import Generic
from typing import Optional
from typing import TypeVar

import apache_beam as beam
from apache_beam.pvalue import PCollection

RequestT = TypeVar('RequestT')
ResponseT = TypeVar('ResponseT')

DEFAULT_TIMEOUT_SECS = 30 # seconds

_LOGGER = logging.getLogger(__name__)


class UserCodeExecutionException(Exception):
"""Base class for errors related to calling Web APIs."""
Expand All @@ -37,8 +50,10 @@ class UserCodeTimeoutException(UserCodeExecutionException):
"""Extends ``UserCodeExecutionException`` to signal a user code timeout."""


class Caller(metaclass=abc.ABCMeta):
"""Interfaces user custom code intended for API calls."""
class Caller(contextlib.AbstractContextManager, abc.ABC):
"""Interface for user custom code intended for API calls.
For setup and teardown of clients when applicable, implement the
``__enter__`` and ``__exit__`` methods respectively."""
@abc.abstractmethod
def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT:
"""Calls a Web API with the ``RequestT`` and returns a
Expand All @@ -48,18 +63,156 @@ def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT:
"""
pass

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
return None


class ShouldBackOff(abc.ABC):
"""
ShouldBackOff provides mechanism to apply adaptive throttling.
"""
pass


class Repeater(abc.ABC):
"""Repeater provides mechanism to repeat requests for a
configurable condition."""
pass

class SetupTeardown(metaclass=abc.ABCMeta):
"""Interfaces user custom code to set up and teardown the API clients.
Called by ``RequestResponseIO`` within its DoFn's setup and teardown
methods.

class CacheReader(abc.ABC):
"""CacheReader provides mechanism to read from the cache."""
pass


class CacheWriter(abc.ABC):
"""CacheWriter provides mechanism to write to the cache."""
pass


class PreCallThrottler(abc.ABC):
"""PreCallThrottler provides a throttle mechanism before sending request."""
pass


class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
beam.PCollection[ResponseT]]):
"""A :class:`RequestResponseIO` transform to read and write to APIs.
Processes an input :class:`~apache_beam.pvalue.PCollection` of requests
by making a call to the API as defined in :class:`Caller`'s `__call__`
and returns a :class:`~apache_beam.pvalue.PCollection` of responses.
"""
def __init__(
self,
caller: [Caller],
timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
should_backoff: Optional[ShouldBackOff] = None,
repeater: Optional[Repeater] = None,
cache_reader: Optional[CacheReader] = None,
cache_writer: Optional[CacheWriter] = None,
throttler: Optional[PreCallThrottler] = None,
):
"""
@abc.abstractmethod
def setup(self) -> None:
"""Called during the DoFn's setup lifecycle method."""
pass
Instantiates a RequestResponseIO transform.
@abc.abstractmethod
def teardown(self) -> None:
"""Called during the DoFn's teardown lifecycle method."""
pass
Args:
caller (~apache_beam.io.requestresponseio.Caller): an implementation of
`Caller` object that makes call to the API.
timeout (float): timeout value in seconds to wait for response from API.
should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff):
(Optional) provides methods for backoff.
repeater (~apache_beam.io.requestresponseio.Repeater): (Optional)
provides methods to repeat requests to API.
cache_reader (~apache_beam.io.requestresponseio.CacheReader): (Optional)
provides methods to read external cache.
cache_writer (~apache_beam.io.requestresponseio.CacheWriter): (Optional)
provides methods to write to external cache.
throttler (~apache_beam.io.requestresponseio.PreCallThrottler):
(Optional) provides methods to pre-throttle a request.
"""
self._caller = caller
self._timeout = timeout
self._should_backoff = should_backoff
self._repeater = repeater
self._cache_reader = cache_reader
self._cache_writer = cache_writer
self._throttler = throttler

def expand(self, requests: PCollection[RequestT]) -> PCollection[ResponseT]:
# TODO(riteshghorse): add Cache and Throttle PTransforms.
return requests | _Call(
caller=self._caller,
timeout=self._timeout,
should_backoff=self._should_backoff,
repeater=self._repeater)


class _Call(beam.PTransform[beam.PCollection[RequestT],
beam.PCollection[ResponseT]]):
"""(Internal-only) PTransform that invokes a remote function on each element
of the input PCollection.
This PTransform uses a `Caller` object to invoke the actual API calls,
and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of
clients when applicable. Additionally, a timeout value is specified to
regulate the duration of each call, defaults to 30 seconds.
Args:
caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable
object that invokes API call.
timeout (float): timeout value in seconds to wait for response from API.
"""
def __init__(
self,
caller: Caller,
timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
should_backoff: Optional[ShouldBackOff] = None,
repeater: Optional[Repeater] = None,
):
"""Initialize the _Call transform.
Args:
caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable
object that invokes API call.
timeout (float): timeout value in seconds to wait for response from API.
should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff):
(Optional) provides methods for backoff.
repeater (~apache_beam.io.requestresponseio.Repeater): (Optional) provides
methods to repeat requests to API.
"""
self._caller = caller
self._timeout = timeout
self._should_backoff = should_backoff
self._repeater = repeater

def expand(
self,
requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
return requests | beam.ParDo(_CallDoFn(self._caller, self._timeout))


class _CallDoFn(beam.DoFn, Generic[RequestT, ResponseT]):
def setup(self):
self._caller.__enter__()

def __init__(self, caller: Caller, timeout: float):
self._caller = caller
self._timeout = timeout

def process(self, request, *args, **kwargs):
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(self._caller, request)
try:
yield future.result(timeout=self._timeout)
except concurrent.futures.TimeoutError:
raise UserCodeTimeoutException(
f'Timeout {self._timeout} exceeded '
f'while completing request: {request}')
except RuntimeError:
raise UserCodeExecutionException('could not complete request')

def teardown(self):
self._caller.__exit__(*sys.exc_info())
18 changes: 15 additions & 3 deletions sdks/python/apache_beam/io/requestresponseio_it_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@

import urllib3

import apache_beam as beam
from apache_beam.io.requestresponseio import Caller
from apache_beam.io.requestresponseio import RequestResponseIO
from apache_beam.io.requestresponseio import UserCodeExecutionException
from apache_beam.io.requestresponseio import UserCodeQuotaException
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline

_HTTP_PATH = '/v1/echo'
_PAYLOAD = base64.b64encode(bytes('payload', 'utf-8'))
Expand Down Expand Up @@ -86,7 +89,6 @@ def __call__(self, request: EchoRequest, *args, **kwargs) -> EchoResponse:
``UserCodeExecutionException``, ``UserCodeTimeoutException``,
or a ``UserCodeQuotaException``.
"""

try:
resp = urllib3.request(
"POST",
Expand All @@ -104,8 +106,8 @@ def __call__(self, request: EchoRequest, *args, **kwargs) -> EchoResponse:

if resp.status == 429: # Too Many Requests
raise UserCodeQuotaException(resp.reason)

raise UserCodeExecutionException(resp.reason)
else:
raise UserCodeExecutionException(resp.status, resp.reason, request)

except urllib3.exceptions.HTTPError as e:
raise UserCodeExecutionException(e)
Expand Down Expand Up @@ -167,6 +169,16 @@ def test_not_found_should_raise(self):
self.assertRaisesRegex(
UserCodeExecutionException, "Not Found", lambda: client(req))

def test_request_response_io(self):
client, options = EchoHTTPCallerTestIT._get_client_and_options()
req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD)
with TestPipeline(is_integration_test=True) as test_pipeline:
output = (
test_pipeline
| 'Create PCollection' >> beam.Create([req])
| 'RRIO Transform' >> RequestResponseIO(client))
self.assertIsNotNone(output)


if __name__ == '__main__':
unittest.main(argv=sys.argv[:1])
88 changes: 88 additions & 0 deletions sdks/python/apache_beam/io/requestresponseio_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import unittest

import apache_beam as beam
from apache_beam.io.requestresponseio import Caller
from apache_beam.io.requestresponseio import RequestResponseIO
from apache_beam.io.requestresponseio import UserCodeExecutionException
from apache_beam.io.requestresponseio import UserCodeTimeoutException
from apache_beam.testing.test_pipeline import TestPipeline


class AckCaller(Caller):
"""AckCaller acknowledges the incoming request by returning a
request with ACK."""
def __enter__(self):
pass

def __call__(self, request: str):
return f"ACK: {request}"

def __exit__(self, exc_type, exc_val, exc_tb):
return None


class CallerWithTimeout(AckCaller):
"""CallerWithTimeout sleeps for 2 seconds before responding.
Used to test timeout in RequestResponseIO."""
def __call__(self, request: str, *args, **kwargs):
time.sleep(2)
return f"ACK: {request}"


class CallerWithRuntimeError(AckCaller):
"""CallerWithRuntimeError raises a `RuntimeError` for RequestResponseIO
to raise a UserCodeExecutionException."""
def __call__(self, request: str, *args, **kwargs):
if not request:
raise RuntimeError("Exception expected, not an error.")


class TestCaller(unittest.TestCase):
def test_valid_call(self):
caller = AckCaller()
with TestPipeline() as test_pipeline:
output = (
test_pipeline
| beam.Create(["sample_request"])
| RequestResponseIO(caller=caller))

self.assertIsNotNone(output)

def test_call_timeout(self):
caller = CallerWithTimeout()
with self.assertRaises(UserCodeTimeoutException):
with TestPipeline() as test_pipeline:
_ = (
test_pipeline
| beam.Create(["timeout_request"])
| RequestResponseIO(caller=caller, timeout=1))

def test_call_runtime_error(self):
caller = CallerWithRuntimeError()
with self.assertRaises(UserCodeExecutionException):
with TestPipeline() as test_pipeline:
_ = (
test_pipeline
| beam.Create([""])
| RequestResponseIO(caller=caller))


if __name__ == '__main__':
unittest.main()

0 comments on commit a7976f5

Please sign in to comment.