diff --git a/pubsub/google/cloud/pubsub_v1/publisher/batch/base.py b/pubsub/google/cloud/pubsub_v1/publisher/batch/base.py index 03705dce9c14..dae0dafb9fd5 100644 --- a/pubsub/google/cloud/pubsub_v1/publisher/batch/base.py +++ b/pubsub/google/cloud/pubsub_v1/publisher/batch/base.py @@ -49,6 +49,16 @@ def __len__(self): """Return the number of messages currently in the batch.""" return len(self.messages) + @staticmethod + @abc.abstractmethod + def make_lock(): + """Return a lock in the chosen concurrency model. + + Returns: + ContextManager: A newly created lock. + """ + raise NotImplementedError + @property @abc.abstractmethod def messages(self): diff --git a/pubsub/google/cloud/pubsub_v1/publisher/batch/thread.py b/pubsub/google/cloud/pubsub_v1/publisher/batch/thread.py index b339865220b9..12b9790c6b80 100644 --- a/pubsub/google/cloud/pubsub_v1/publisher/batch/thread.py +++ b/pubsub/google/cloud/pubsub_v1/publisher/batch/thread.py @@ -90,6 +90,15 @@ def __init__(self, client, topic, settings, autocommit=True): ) self._thread.start() + @staticmethod + def make_lock(): + """Return a threading lock. + + Returns: + _thread.Lock: A newly created lock. + """ + return threading.Lock() + @property def client(self): """~.pubsub_v1.client.PublisherClient: A publisher client.""" diff --git a/pubsub/google/cloud/pubsub_v1/publisher/client.py b/pubsub/google/cloud/pubsub_v1/publisher/client.py index 3b1a7a2a7d2d..d2faedad1d8a 100644 --- a/pubsub/google/cloud/pubsub_v1/publisher/client.py +++ b/pubsub/google/cloud/pubsub_v1/publisher/client.py @@ -17,7 +17,6 @@ import copy import os import pkg_resources -import threading import grpc import six @@ -44,16 +43,21 @@ class Client(object): Args: batch_settings (~google.cloud.pubsub_v1.types.BatchSettings): The settings for batch publishing. - batch_class (class): A class that describes how to handle + batch_class (Optional[type]): A class that describes how to handle batches. You may subclass the :class:`.pubsub_v1.publisher.batch.base.BaseBatch` class in order to define your own batcher. This is primarily provided to allow use of different concurrency models; the default - is based on :class:`threading.Thread`. + is based on :class:`threading.Thread`. This class should also have + a class method (or static method) that takes no arguments and + produces a lock that can be used as a context manager. kwargs (dict): Any additional arguments provided are sent as keyword arguments to the underlying :class:`~.gapic.pubsub.v1.publisher_client.PublisherClient`. Generally, you should not need to set additional keyword arguments. + Before being passed along to the GAPIC constructor, a channel may + be added if ``credentials`` are passed explicitly or if the + Pub / Sub emulator is detected as running. """ def __init__(self, batch_settings=(), batch_class=thread.Batch, **kwargs): # Sanity check: Is our goal to use the emulator? @@ -86,7 +90,7 @@ def __init__(self, batch_settings=(), batch_class=thread.Batch, **kwargs): # The batches on the publisher client are responsible for holding # messages. One batch exists for each topic. self._batch_class = batch_class - self._batch_lock = threading.Lock() + self._batch_lock = batch_class.make_lock() self._batches = {} @property diff --git a/pubsub/tests/unit/pubsub_v1/publisher/batch/test_thread.py b/pubsub/tests/unit/pubsub_v1/publisher/batch/test_thread.py index 903ae90794a4..2c8852576308 100644 --- a/pubsub/tests/unit/pubsub_v1/publisher/batch/test_thread.py +++ b/pubsub/tests/unit/pubsub_v1/publisher/batch/test_thread.py @@ -72,6 +72,13 @@ def test_init_infinite_latency(): assert batch._thread is None +@mock.patch.object(threading, 'Lock') +def test_make_lock(Lock): + lock = Batch.make_lock() + assert lock is Lock.return_value + Lock.assert_called_once_with() + + def test_client(): client = create_client() settings = types.BatchSettings() diff --git a/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py b/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py index a519ddc645fd..55a4990761d4 100644 --- a/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py +++ b/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py @@ -16,8 +16,8 @@ import os from google.auth import credentials -import mock +import mock import pytest from google.cloud.pubsub_v1.gapic import publisher_client