Skip to content

Commit

Permalink
Adding pause/resume to Pub / Sub consumer. (#4558)
Browse files Browse the repository at this point in the history
Using these (rather then open/close on the subscription Policy)
when the flow control signals the message load is too great.
  • Loading branch information
dhermes authored Dec 8, 2017
1 parent 061011d commit 9ef8334
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 39 deletions.
55 changes: 54 additions & 1 deletion pubsub/google/cloud/pubsub_v1/subscriber/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class Consumer(object):
def __init__(self):
self._request_queue = queue.Queue()
self.stopped = threading.Event()
self._can_consume = threading.Event()
self._put_lock = threading.Lock()
self._consumer_thread = None

Expand Down Expand Up @@ -319,8 +320,10 @@ def _blocking_consume(self, policy):

request_generator = self._request_generator_thread(policy)
response_generator = policy.call_rpc(request_generator)
responses = _pausable_iterator(
response_generator, self._can_consume)
try:
for response in response_generator:
for response in responses:
_LOGGER.debug('Received response:\n%r', response)
policy.on_response(response)

Expand All @@ -339,6 +342,34 @@ def _blocking_consume(self, policy):
self._stop_no_join()
return

def pause(self):
"""Pause the current consumer.
This method is idempotent by design.
This will clear the ``_can_consume`` event which is checked
every time :meth:`_blocking_consume` consumes a response from the
bidirectional streaming pull.
Complement to :meth:`resume`.
"""
_LOGGER.debug('Pausing consumer')
self._can_consume.clear()

def resume(self):
"""Resume the current consumer.
This method is idempotent by design.
This will set the ``_can_consume`` event which is checked
every time :meth:`_blocking_consume` consumes a response from the
bidirectional streaming pull.
Complement to :meth:`pause`.
"""
_LOGGER.debug('Resuming consumer')
self._can_consume.set()

def start_consuming(self, policy):
"""Start consuming the stream.
Expand All @@ -351,6 +382,7 @@ def start_consuming(self, policy):
responses are handled.
"""
self.stopped.clear()
self.resume() # Make sure we aren't paused.
thread = threading.Thread(
name=_BIDIRECTIONAL_CONSUMER_NAME,
target=self._blocking_consume,
Expand All @@ -374,6 +406,7 @@ def _stop_no_join(self):
threading.Thread: The worker ("consumer thread") that is being
stopped.
"""
self.resume() # Make sure we aren't paused.
self.stopped.set()
_LOGGER.debug('Stopping helper thread %s', self._consumer_thread.name)
self.send_request(_helper_threads.STOP)
Expand All @@ -392,3 +425,23 @@ def stop_consuming(self):
"""
thread = self._stop_no_join()
thread.join()


def _pausable_iterator(iterator, can_continue):
"""Converts a standard iterator into one that can be paused.
The ``can_continue`` event can be used by an independent, concurrent
worker to pause and resume the iteration over ``iterator``.
Args:
iterator (Iterator): Any iterator to be iterated over.
can_continue (threading.Event): An event which determines if we
can advance to the next iteration. Will be ``wait()``-ed on
before
Yields:
Any: The items from ``iterator``.
"""
while True:
can_continue.wait()
yield next(iterator)
4 changes: 2 additions & 2 deletions pubsub/google/cloud/pubsub_v1/subscriber/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def drop(self, ack_id, byte_size):
# before restarting.
if self._paused and self._load < self.flow_control.resume_threshold:
self._paused = False
self.open(self._callback)
self._consumer.resume()

def get_initial_request(self, ack_queue=False):
"""Return the initial request.
Expand Down Expand Up @@ -291,7 +291,7 @@ def lease(self, ack_id, byte_size):
# If we do, we need to stop the stream.
if self._load >= 1.0:
self._paused = True
self.close()
self._consumer.pause()

def maintain_leases(self):
"""Maintain all of the leases being managed by the policy.
Expand Down
67 changes: 55 additions & 12 deletions pubsub/tests/unit/pubsub_v1/subscriber/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ def test_request_generator_thread():

def test_blocking_consume():
policy = mock.Mock(spec=('call_rpc', 'on_response'))
policy.call_rpc.return_value = (mock.sentinel.A, mock.sentinel.B)
policy.call_rpc.return_value = iter((mock.sentinel.A, mock.sentinel.B))

consumer = _consumer.Consumer()
consumer.resume()

assert consumer._blocking_consume(policy) is None
policy.call_rpc.assert_called_once()
policy.on_response.assert_has_calls(
Expand Down Expand Up @@ -96,11 +98,12 @@ def __call__(self, exception):

def test_blocking_consume_on_exception():
policy = mock.Mock(spec=('call_rpc', 'on_response', 'on_exception'))
policy.call_rpc.return_value = (mock.sentinel.A, mock.sentinel.B)
policy.call_rpc.return_value = iter((mock.sentinel.A, mock.sentinel.B))
exc = TypeError('Bad things!')
policy.on_response.side_effect = exc

consumer = _consumer.Consumer()
consumer.resume()
consumer._consumer_thread = mock.Mock(spec=threading.Thread)
policy.on_exception.side_effect = OnException()

Expand All @@ -114,37 +117,77 @@ def test_blocking_consume_on_exception():
policy.on_exception.assert_called_once_with(exc)


class RaisingResponseGenerator(object):
# NOTE: This is needed because defining `.next` on an **instance**
# rather than the **class** will not be iterable in Python 2.
# This is problematic since a `Mock` just sets members.

def __init__(self, exception):
self.exception = exception
self.done_calls = 0
self.next_calls = 0

def done(self):
self.done_calls += 1
return True

def __next__(self):
self.next_calls += 1
raise self.exception

def next(self):
return self.__next__() # Python 2


def test_blocking_consume_two_exceptions():
policy = mock.Mock(spec=('call_rpc', 'on_exception'))

exc1 = NameError('Oh noes.')
exc2 = ValueError('Something grumble.')
policy.on_exception.side_effect = OnException(acceptable=exc1)

response_generator1 = mock.MagicMock(spec=('__iter__', 'done'))
response_generator1.__iter__.side_effect = exc1
response_generator1.done.return_value = True
response_generator2 = mock.MagicMock(spec=('__iter__', 'done'))
response_generator2.__iter__.side_effect = exc2
response_generator1 = RaisingResponseGenerator(exc1)
response_generator2 = RaisingResponseGenerator(exc2)
policy.call_rpc.side_effect = (response_generator1, response_generator2)

consumer = _consumer.Consumer()
consumer.resume()
consumer._consumer_thread = mock.Mock(spec=threading.Thread)

# Establish that we get responses until we are sent the exiting event.
consumer._blocking_consume(policy)
assert consumer._blocking_consume(policy) is None
assert consumer._consumer_thread is None

# Check mocks.
assert policy.call_rpc.call_count == 2
response_generator1.__iter__.assert_called_once_with()
response_generator1.done.assert_called_once_with()
response_generator2.__iter__.assert_called_once_with()
response_generator2.done.assert_not_called()
assert response_generator1.next_calls == 1
assert response_generator1.done_calls == 1
assert response_generator2.next_calls == 1
assert response_generator2.done_calls == 0
policy.on_exception.assert_has_calls(
[mock.call(exc1), mock.call(exc2)])


@mock.patch.object(_consumer, '_LOGGER')
def test_pause(_LOGGER):
consumer = _consumer.Consumer()
consumer._can_consume.set()

assert consumer.pause() is None
assert not consumer._can_consume.is_set()
_LOGGER.debug.assert_called_once_with('Pausing consumer')


@mock.patch.object(_consumer, '_LOGGER')
def test_resume(_LOGGER):
consumer = _consumer.Consumer()
consumer._can_consume.clear()

assert consumer.resume() is None
assert consumer._can_consume.is_set()
_LOGGER.debug.assert_called_once_with('Resuming consumer')


def test_start_consuming():
creds = mock.Mock(spec=credentials.Credentials)
client = subscriber.Client(credentials=creds)
Expand Down
58 changes: 34 additions & 24 deletions pubsub/tests/unit/pubsub_v1/subscriber/test_policy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,13 @@ def test_ack_no_time():
def test_ack_paused():
policy = create_policy()
policy._paused = True
policy._consumer.stopped.clear()
with mock.patch.object(policy, 'open') as open_:
consumer = policy._consumer

with mock.patch.object(consumer, 'resume') as resume:
policy.ack('ack_id_string')
open_.assert_called()
resume.assert_called_once_with()

assert policy._paused is False
assert 'ack_id_string' in policy._ack_on_resume


Expand Down Expand Up @@ -157,33 +160,38 @@ def test_drop_below_threshold():
"""
policy = create_policy()
policy.managed_ack_ids.add('ack_id_string')
policy._bytes = 20
num_bytes = 20
policy._bytes = num_bytes
policy._paused = True
with mock.patch.object(policy, 'open') as open_:
policy.drop(ack_id='ack_id_string', byte_size=20)
open_.assert_called_once_with(policy._callback)
consumer = policy._consumer

with mock.patch.object(consumer, 'resume') as resume:
policy.drop(ack_id='ack_id_string', byte_size=num_bytes)
resume.assert_called_once_with()

assert policy._paused is False


def test_load():
flow_control = types.FlowControl(max_messages=10, max_bytes=1000)
policy = create_policy(flow_control=flow_control)

# This should mean that our messages count is at 10%, and our bytes
# are at 15%; the ._load property should return the higher (0.15).
policy.lease(ack_id='one', byte_size=150)
assert policy._load == 0.15

# After this message is added, the messages should be higher at 20%
# (versus 16% for bytes).
policy.lease(ack_id='two', byte_size=10)
assert policy._load == 0.2

# Returning a number above 100% is fine.
with mock.patch.object(policy, 'close') as close:
consumer = policy._consumer

with mock.patch.object(consumer, 'pause') as pause:
# This should mean that our messages count is at 10%, and our bytes
# are at 15%; the ._load property should return the higher (0.15).
policy.lease(ack_id='one', byte_size=150)
assert policy._load == 0.15
pause.assert_not_called()
# After this message is added, the messages should be higher at 20%
# (versus 16% for bytes).
policy.lease(ack_id='two', byte_size=10)
assert policy._load == 0.2
pause.assert_not_called()
# Returning a number above 100% is fine.
policy.lease(ack_id='three', byte_size=1000)
assert policy._load == 1.16
close.assert_called_once_with()
pause.assert_called_once_with()


def test_modify_ack_deadline():
Expand Down Expand Up @@ -251,11 +259,13 @@ def test_lease():
def test_lease_above_threshold():
flow_control = types.FlowControl(max_messages=2)
policy = create_policy(flow_control=flow_control)
with mock.patch.object(policy, 'close') as close:
consumer = policy._consumer

with mock.patch.object(consumer, 'pause') as pause:
policy.lease(ack_id='first_ack_id', byte_size=20)
assert close.call_count == 0
pause.assert_not_called()
policy.lease(ack_id='second_ack_id', byte_size=25)
close.assert_called_once_with()
pause.assert_called_once_with()


def test_nack():
Expand Down

0 comments on commit 9ef8334

Please sign in to comment.