Skip to content

Commit

Permalink
Merge pull request #4 from cldcvr/development
Browse files Browse the repository at this point in the history
feat: direct call to create cloud task for delayed messages
  • Loading branch information
sauravcld committed Dec 16, 2020
2 parents d1d0f09 + 7094160 commit 9f30f3f
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 49 deletions.
157 changes: 108 additions & 49 deletions kombu/transport/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import os
import sys
from dateutil import parser
from threading import Thread

from anyjson import dumps, loads
from amqp.protocol import queue_declare_ok_t
Expand All @@ -15,13 +17,40 @@

try:
from google.cloud import pubsub_v1
from google.api_core.exceptions import AlreadyExists
from google.cloud import tasks_v2
from google.protobuf import timestamp_pb2
from google.api_core.exceptions import AlreadyExists, DeadlineExceeded
except:
pubsub_v1 = None

logger = get_logger(__name__)


class Worker(Thread):
''' Worker thread '''
def __init__(self, client, subscription_path, max_messages, queue):
Thread.__init__(self)
self.subscriber = client
self.subscription_path = subscription_path
self.queue = queue
self.max_messages = max_messages
self.start()

def run(self):
''' run '''
while True:
logger.info("".join(["Pulling messsage using subscription ",
self.subscription_path]))
try:
resp = self.subscriber.pull(self.subscription_path,
self.max_messages, timeout=0.3)
except (ValueError, DeadlineExceeded):
continue
if resp.received_messages:
for msg in resp.received_messages:
self.queue.put(msg, block=True)


class Message(base.Message):
def __init__(self, channel, msg, **kwargs):
body, props = self._translate_message(msg)
Expand Down Expand Up @@ -130,18 +159,10 @@ def _get(self, queue):
raise Empty()
subscription_path = self._new_queue(queue)
if not self.temp_cache[subscription_path].empty():
return self.temp_cache[subscription_path].get(block=True)
logger.info("".join(["Pulling messsage using subscription ", subscription_path]))
resp = self.subscriber.\
pull(subscription_path, self.max_messages, return_immediately=True)
if resp.received_messages:
for msg in resp.received_messages:
if self.temp_cache[subscription_path].full():
break
self.qos.append(msg.message.message_id,
(msg, subscription_path))
self.temp_cache[subscription_path].put(msg)
return self.temp_cache[subscription_path].get(block=True)
msg = self.temp_cache[subscription_path].get(block=True)
self.qos.append(
msg.message.message_id, (msg, subscription_path))
return msg
raise Empty()

def queue_declare(self, queue=None, passive=False, *args, **kwargs):
Expand Down Expand Up @@ -170,8 +191,8 @@ def queue_bind(self, *args, **kwargs):
"""
subscription_path = self._new_queue(kwargs.get('queue'))
topic_path = self.state.exchanges[kwargs.get('exchange')]
self.temp_cache[subscription_path] =\
Queue(maxsize=self.max_messages)
queue = Queue(maxsize=self.max_messages)
self.temp_cache[subscription_path] = queue
try:
self.subscriber.create_subscription(
subscription_path, topic_path,
Expand All @@ -180,6 +201,11 @@ def queue_bind(self, *args, **kwargs):
except AlreadyExists:
logger.info("".join(["Subscription already exists: ", subscription_path]))
pass
if 'celery' in subscription_path:
return
# Start worker
logger.info("".join(["Starting worker: ", subscription_path, " with queue size: ", str(self.max_messages)]))
Worker(self.subscriber, subscription_path, self.max_messages, queue)

def exchange_declare(self, exchange='', **kwargs):
"""Declare a topic in PubSub
Expand Down Expand Up @@ -217,28 +243,45 @@ def basic_publish(self, message, exchange='', routing_key='',
:param exchange: topic name
:type body: str
"""
eta = loads(message['body'])['eta']
if eta:
topic = self.delayed_topic
if topic is None:
raise ChannelError(
'Cannot publish message id {0!r} to None delayed topic'.\
format(loads(message['body'])['id']))
topic_path = self.publisher.topic_path(
self.project_id, topic)
message = dumps({
'destination_topic': exchange,
'eta': eta,
'message': message
}).encode('utf-8')
else:
topic_path =\
if loads(message['body'])['eta']:
return self._create_cloud_task(exchange, message)
return self._publish(exchange, message)

def _publish(self, topic, message):
''' publish the message '''
topic_path =\
self.publisher.topic_path(
self.project_id, exchange)
message = dumps(message).encode('utf-8')
message = dumps(message).encode('utf-8')
future = self.publisher.publish(
topic_path, message, **kwargs)
return future.result()
return future.result()

def _create_cloud_task(self, exchange, message):
''' send task to cloud task '''
eta = loads(message['body'])['eta']
task = self._get_task(eta, exchange, message)
return self.cloud_task.create_task(self.cloud_task_queue_path, task)

def _get_task(self, eta, exchange, message):
parsed_time = parser.parse(eta.strip())
ts = timestamp_pb2.Timestamp()
ts.FromDatetime(parsed_time)
return {
"http_request": {
"http_method": tasks_v2.enums.HttpMethod.POST,
"headers": {"Content-type": "application/json"},
"url": self.transport_options.get("CLOUD_FUNCTION_PUBLISHER"),
"body": dumps({
'destination_topic': exchange,
'eta': eta,
'message': message
}).encode('utf-8'),
},
"name": self.cloud_task_queue_path + "/tasks/" + "_".join(
[exchange, uuid()]),
"schedule_time": ts,
}

@cached_property
def publisher(self):
Expand All @@ -251,21 +294,14 @@ def subscriber(self):
return pubsub_v1.SubscriberClient()

@cached_property
def ack_deadline_seconds(self):
"""Deadline for acknowledgement from the time received.
This is notified to PubSub while subscribing from the client.
"""
return self.transport_options.get('ACK_DEADLINE_SECONDS', 60)

@cached_property
def delayed_topic(self):
"""Delayed topic used to support delay messages in celery"""
return self.transport_options.get('DELAYED_TOPIC', None)
def cloud_task(self):
""" Client connection for cloud task """
return tasks_v2.CloudTasksClient()

@cached_property
def max_messages(self):
"""Maximum messages to pull into local cache"""
return self.transport_options.get('MAX_MESSAGES', 10)
def transport_options(self):
"""PubSub Transport sepcific configurations"""
return self.connection.client.transport_options

@cached_property
def project_id(self):
Expand All @@ -275,9 +311,32 @@ def project_id(self):
return self.transport_options.get('PROJECT_ID', '')

@cached_property
def transport_options(self):
"""PubSub Transport sepcific configurations"""
return self.connection.client.transport_options
def max_messages(self):
"""Maximum messages to pull into local cache"""
return self.transport_options.get('MAX_MESSAGES', 10)

@cached_property
def ack_deadline_seconds(self):
"""Deadline for acknowledgement from the time received.
This is notified to PubSub while subscribing from the client.
"""
return self.transport_options.get('ACK_DEADLINE_SECONDS', 60)

@cached_property
def cloud_task_queue_path(self):
""" Cloud task queue path """
return self.cloud_task.queue_path(
self.project_id, self.location, self.delayed_queue)

@cached_property
def location(self):
""" Cloud task queue location """
return self.transport_options.get('QUEUE_LOCATION', None)

@cached_property
def delayed_queue(self):
"""Delayed topic used to support delay messages in celery"""
return self.transport_options.get('DELAYED_QUEUE', None)


class Transport(virtual.Transport):
Expand Down
1 change: 1 addition & 0 deletions requirements/test-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ coverage>=3.0
coveralls
google-api-core==1.23.0
google-cloud-pubsub==1.7.0
google-cloud-tasks==1.5.0
redis
PyYAML
msgpack-python>0.2.0 # 0.2.0 dropped 2.5 support

0 comments on commit 9f30f3f

Please sign in to comment.