diff --git a/kombu/transport/pubsub.py b/kombu/transport/pubsub.py index 9c16a2ca4..7ef4da7cb 100644 --- a/kombu/transport/pubsub.py +++ b/kombu/transport/pubsub.py @@ -2,6 +2,8 @@ import os import sys +from dateutil import parser +from threading import Thread from anyjson import dumps, loads from amqp.protocol import queue_declare_ok_t @@ -15,13 +17,40 @@ try: from google.cloud import pubsub_v1 - from google.api_core.exceptions import AlreadyExists + from google.cloud import tasks_v2 + from google.protobuf import timestamp_pb2 + from google.api_core.exceptions import AlreadyExists, DeadlineExceeded except: pubsub_v1 = None logger = get_logger(__name__) +class Worker(Thread): + ''' Worker thread ''' + def __init__(self, client, subscription_path, max_messages, queue): + Thread.__init__(self) + self.subscriber = client + self.subscription_path = subscription_path + self.queue = queue + self.max_messages = max_messages + self.start() + + def run(self): + ''' run ''' + while True: + logger.info("".join(["Pulling messsage using subscription ", + self.subscription_path])) + try: + resp = self.subscriber.pull(self.subscription_path, + self.max_messages, timeout=0.3) + except (ValueError, DeadlineExceeded): + continue + if resp.received_messages: + for msg in resp.received_messages: + self.queue.put(msg, block=True) + + class Message(base.Message): def __init__(self, channel, msg, **kwargs): body, props = self._translate_message(msg) @@ -130,18 +159,10 @@ def _get(self, queue): raise Empty() subscription_path = self._new_queue(queue) if not self.temp_cache[subscription_path].empty(): - return self.temp_cache[subscription_path].get(block=True) - logger.info("".join(["Pulling messsage using subscription ", subscription_path])) - resp = self.subscriber.\ - pull(subscription_path, self.max_messages, return_immediately=True) - if resp.received_messages: - for msg in resp.received_messages: - if self.temp_cache[subscription_path].full(): - break - self.qos.append(msg.message.message_id, - (msg, subscription_path)) - self.temp_cache[subscription_path].put(msg) - return self.temp_cache[subscription_path].get(block=True) + msg = self.temp_cache[subscription_path].get(block=True) + self.qos.append( + msg.message.message_id, (msg, subscription_path)) + return msg raise Empty() def queue_declare(self, queue=None, passive=False, *args, **kwargs): @@ -170,8 +191,8 @@ def queue_bind(self, *args, **kwargs): """ subscription_path = self._new_queue(kwargs.get('queue')) topic_path = self.state.exchanges[kwargs.get('exchange')] - self.temp_cache[subscription_path] =\ - Queue(maxsize=self.max_messages) + queue = Queue(maxsize=self.max_messages) + self.temp_cache[subscription_path] = queue try: self.subscriber.create_subscription( subscription_path, topic_path, @@ -180,6 +201,11 @@ def queue_bind(self, *args, **kwargs): except AlreadyExists: logger.info("".join(["Subscription already exists: ", subscription_path])) pass + if 'celery' in subscription_path: + return + # Start worker + logger.info("".join(["Starting worker: ", subscription_path, " with queue size: ", str(self.max_messages)])) + Worker(self.subscriber, subscription_path, self.max_messages, queue) def exchange_declare(self, exchange='', **kwargs): """Declare a topic in PubSub @@ -217,28 +243,45 @@ def basic_publish(self, message, exchange='', routing_key='', :param exchange: topic name :type body: str """ - eta = loads(message['body'])['eta'] - if eta: - topic = self.delayed_topic - if topic is None: - raise ChannelError( - 'Cannot publish message id {0!r} to None delayed topic'.\ - format(loads(message['body'])['id'])) - topic_path = self.publisher.topic_path( - self.project_id, topic) - message = dumps({ - 'destination_topic': exchange, - 'eta': eta, - 'message': message - }).encode('utf-8') - else: - topic_path =\ + if loads(message['body'])['eta']: + return self._create_cloud_task(exchange, message) + return self._publish(exchange, message) + + def _publish(self, topic, message): + ''' publish the message ''' + topic_path =\ self.publisher.topic_path( self.project_id, exchange) - message = dumps(message).encode('utf-8') + message = dumps(message).encode('utf-8') future = self.publisher.publish( topic_path, message, **kwargs) - return future.result() + return future.result() + + def _create_cloud_task(self, exchange, message): + ''' send task to cloud task ''' + eta = loads(message['body'])['eta'] + task = self._get_task(eta, exchange, message) + return self.cloud_task.create_task(self.cloud_task_queue_path, task) + + def _get_task(self, eta, exchange, message): + parsed_time = parser.parse(eta.strip()) + ts = timestamp_pb2.Timestamp() + ts.FromDatetime(parsed_time) + return { + "http_request": { + "http_method": tasks_v2.enums.HttpMethod.POST, + "headers": {"Content-type": "application/json"}, + "url": self.transport_options.get("CLOUD_FUNCTION_PUBLISHER"), + "body": dumps({ + 'destination_topic': exchange, + 'eta': eta, + 'message': message + }).encode('utf-8'), + }, + "name": self.cloud_task_queue_path + "/tasks/" + "_".join( + [exchange, uuid()]), + "schedule_time": ts, + } @cached_property def publisher(self): @@ -251,21 +294,14 @@ def subscriber(self): return pubsub_v1.SubscriberClient() @cached_property - def ack_deadline_seconds(self): - """Deadline for acknowledgement from the time received. - This is notified to PubSub while subscribing from the client. - """ - return self.transport_options.get('ACK_DEADLINE_SECONDS', 60) - - @cached_property - def delayed_topic(self): - """Delayed topic used to support delay messages in celery""" - return self.transport_options.get('DELAYED_TOPIC', None) + def cloud_task(self): + """ Client connection for cloud task """ + return tasks_v2.CloudTasksClient() @cached_property - def max_messages(self): - """Maximum messages to pull into local cache""" - return self.transport_options.get('MAX_MESSAGES', 10) + def transport_options(self): + """PubSub Transport sepcific configurations""" + return self.connection.client.transport_options @cached_property def project_id(self): @@ -275,9 +311,32 @@ def project_id(self): return self.transport_options.get('PROJECT_ID', '') @cached_property - def transport_options(self): - """PubSub Transport sepcific configurations""" - return self.connection.client.transport_options + def max_messages(self): + """Maximum messages to pull into local cache""" + return self.transport_options.get('MAX_MESSAGES', 10) + + @cached_property + def ack_deadline_seconds(self): + """Deadline for acknowledgement from the time received. + This is notified to PubSub while subscribing from the client. + """ + return self.transport_options.get('ACK_DEADLINE_SECONDS', 60) + + @cached_property + def cloud_task_queue_path(self): + """ Cloud task queue path """ + return self.cloud_task.queue_path( + self.project_id, self.location, self.delayed_queue) + + @cached_property + def location(self): + """ Cloud task queue location """ + return self.transport_options.get('QUEUE_LOCATION', None) + + @cached_property + def delayed_queue(self): + """Delayed topic used to support delay messages in celery""" + return self.transport_options.get('DELAYED_QUEUE', None) class Transport(virtual.Transport): diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt index e4e7d733b..87b86f711 100644 --- a/requirements/test-ci.txt +++ b/requirements/test-ci.txt @@ -3,6 +3,7 @@ coverage>=3.0 coveralls google-api-core==1.23.0 google-cloud-pubsub==1.7.0 +google-cloud-tasks==1.5.0 redis PyYAML msgpack-python>0.2.0 # 0.2.0 dropped 2.5 support