Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Add deadlines to message queue entries #5106

Merged
merged 5 commits into from
Feb 27, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion golem/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def execute_sql(self, sql, params=None, require_commit=True):


class Database:
SCHEMA_VERSION = 46
SCHEMA_VERSION = 47

def __init__(self, # noqa pylint: disable=too-many-arguments
db: peewee.Database,
Expand Down
18 changes: 18 additions & 0 deletions golem/database/schemas/047_msg_queue_deadline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# pylint: disable=no-member
# pylint: disable=unused-argument
import peewee as pw

from golem.model import default_msg_deadline

SCHEMA_VERSION = 47


def migrate(migrator, database, fake=False, **kwargs):
migrator.add_fields(
'queuedmessage',
deadline=pw.UTCDateTimeField(default=default_msg_deadline())
)


def rollback(migrator, database, fake=False, **kwargs):
migrator.remove_fields('queuedmessage', 'deadline')
16 changes: 15 additions & 1 deletion golem/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from golem.core import common
from golem.core.common import datetime_to_timestamp, default_now
from golem.core.simpleserializer import DictSerializable
from golem.core.variables import MESSAGE_QUEUE_MAX_AGE
from golem.database import GolemSqliteDatabase
from golem.ranking.helper.trust_const import NEUTRAL_TRUST
from golem.ranking import ProviderEfficacy
Expand Down Expand Up @@ -649,17 +650,29 @@ def as_message(self) -> message.base.Message:
return msg


def default_msg_deadline() -> datetime.datetime:
return default_now() + MESSAGE_QUEUE_MAX_AGE


class QueuedMessage(BaseModel):
node = CharField(null=False, index=True)
msg_version = VersionField(null=False)
msg_cls = CharField(null=False)
msg_data = BlobField(null=False)
deadline = UTCDateTimeField(
null=False,
default=default_msg_deadline()
)

class Meta:
database = db

@classmethod
def from_message(cls, node_id: str, msg: message.base.Message):
def from_message(
cls,
node_id: str,
msg: message.base.Message,
deadline: Optional[datetime.datetime] = None):
instance = cls()
instance.node = node_id
instance.msg_cls = '.'.join(
Expand All @@ -669,6 +682,7 @@ def from_message(cls, node_id: str, msg: message.base.Message):
golem_messages.__version__,
)
instance.msg_data = golem_messages.dump(msg, None, None)
instance.deadline = deadline or default_msg_deadline() # type: ignore
etam marked this conversation as resolved.
Show resolved Hide resolved
return instance

def as_message(self) -> message.base.Message:
Expand Down
39 changes: 32 additions & 7 deletions golem/network/transport/msg_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from golem import decorators
from golem import model
from golem.core import variables
from golem.core.common import short_node_id
from golem.core.common import default_now, short_node_id


logger = logging.getLogger(__name__)
Expand All @@ -25,12 +25,17 @@
)


def put(node_id: str, msg: message.base.Message) -> None:
def put(
node_id: str,
msg: message.base.Message,
timeout: typing.Optional[datetime.timedelta] = None
) -> None:
assert not isinstance(msg, FORBIDDEN_CLASSES),\
"Disconnect message shouldn't be in a queue"
logger.debug("saving into queue node_id=%s, msg=%r",
short_node_id(node_id), msg)
db_model = model.QueuedMessage.from_message(node_id, msg)
deadline_utc = (default_now() + timeout) if timeout else None
db_model = model.QueuedMessage.from_message(node_id, msg, deadline_utc)
db_model.save()


Expand All @@ -44,7 +49,17 @@ def get(node_id: str) -> typing.Iterator['message.base.Base']:
).order_by(model.QueuedMessage.created_date).get()
except model.QueuedMessage.DoesNotExist:
return

try:
if db_model.deadline <= default_now():
logger.debug(
'deleting message past its deadline.'
' db_model=%s, deadline=%s',
db_model,
db_model.deadline
)
continue

msg = db_model.as_message()
except msg_exceptions.VersionMismatchError:
logger.info(
Expand Down Expand Up @@ -73,6 +88,8 @@ def get(node_id: str) -> typing.Iterator['message.base.Base']:
def waiting() -> typing.Iterator[str]:
query = model.QueuedMessage.select(
model.QueuedMessage.node,
).where(
model.QueuedMessage.deadline > default_now()
).group_by(model.QueuedMessage.node)
try:
for db_row in query:
Expand All @@ -90,12 +107,20 @@ def waiting() -> typing.Iterator[str]:

@decorators.run_with_db()
def sweep() -> None:
"""Sweep ancient messages"""
"""Sweep messages"""
with READ_LOCK:
oldest_allowed = datetime.datetime.now() \
now = default_now()
count = 0

count += model.QueuedMessage.delete().where(
model.QueuedMessage.deadline <= now
).execute()

oldest_allowed = now \
- variables.MESSAGE_QUEUE_MAX_AGE
count = model.QueuedMessage.delete().where(
count += model.QueuedMessage.delete().where(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use only deadline here

model.QueuedMessage.created_date < oldest_allowed,
).execute()

if count:
logger.info('Sweeped ancient messages from queue. count=%d', count)
logger.info('Sweeped messages from queue. count=%d', count)
2 changes: 2 additions & 0 deletions golem/task/server/resources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import logging
import os
from typing import (
Expand Down Expand Up @@ -286,6 +287,7 @@ def _nonce_shared(self, key_id, result, options):
msg=message.resources.ResourceHandshakeStart(
resource=handshake.hash, options=options.__dict__,
),
timeout=datetime.timedelta(seconds=self.HANDSHAKE_TIMEOUT)
)

def _share_handshake_nonce(self, key_id):
Expand Down
3 changes: 3 additions & 0 deletions golem/task/taskserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=too-many-lines

import asyncio
import datetime
import functools
import itertools
import logging
Expand Down Expand Up @@ -516,6 +517,8 @@ def _request_task(self, theader: dt_tasks.TaskHeader) -> Deferred:
msg_queue.put(
node_id=theader.task_owner.key,
msg=wtct,
timeout=datetime.timedelta(
seconds=deadline_to_timeout(theader.deadline))
)

timer.ProviderTTCDelayTimers.start(wtct.task_id)
Expand Down
52 changes: 50 additions & 2 deletions tests/golem/network/transport/test_msg_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from golem import model
from golem import testutils
from golem.core.common import default_now
from golem.model import default_msg_deadline
from golem.network.transport import msg_queue


Expand All @@ -19,6 +21,7 @@ def setUp(self):
self.node_id = str(uuid.uuid4())
self.msg = tasks_factories.WantToComputeTaskFactory()

@freeze_time()
def test_put(self):
msg_queue.put(self.node_id, self.msg)
row = model.QueuedMessage.get()
Expand All @@ -27,6 +30,7 @@ def test_put(self):
'golem_messages.message.tasks.WantToComputeTask',
)
self.assertEqual(str(row.msg_version), golem_messages.__version__)
self.assertEqual(row.deadline, default_msg_deadline())
row_msg = row.as_message()
self.assertEqual(row_msg.slots(), self.msg.slots())
self.assertIsNone(row_msg.sig)
Expand All @@ -39,6 +43,17 @@ def test_get(self):
self.assertEqual(msg.slots(), self.msg.slots())
self.assertEqual(len(list(msg_queue.get(self.node_id))), 0)

@freeze_time()
def test_get_timeout(self):
timeout = datetime.timedelta(seconds=1)
msg_queue.put(self.node_id, self.msg, timeout)

with freeze_time(default_now() + timeout):
msgs = list(msg_queue.get(self.node_id))

self.assertEqual(len(msgs), 0)
self.assertEqual(len(list(msg_queue.get(self.node_id))), 0)

def test_waiting(self):
node_id2 = str(uuid.uuid4())
node_id3 = str(uuid.uuid4())
Expand Down Expand Up @@ -67,11 +82,31 @@ def test_waiting_programming_error(self, *_args):
waiting = frozenset(msg_queue.waiting())
self.assertEqual(waiting, set())

@freeze_time()
def test_waiting_timeout(self):
timeout = datetime.timedelta(hours=1)
node_id2 = str(uuid.uuid4())
node_id_timeout = str(uuid.uuid4())
msg_queue.put(self.node_id, self.msg)
msg_queue.put(node_id2, self.msg)
msg_queue.put(node_id_timeout, self.msg, timeout)

with freeze_time(default_now() + datetime.timedelta(hours=2)):
waiting = frozenset(msg_queue.waiting())

self.assertEqual(
waiting,
set([
self.node_id,
node_id2
]),
)

def test_sweep(self):
def put_explicit_now():
instance = model.QueuedMessage.from_message(self.node_id, self.msg)
# peewee/sqlite is freezegun resistant
instance.created_date = datetime.datetime.now()
instance.created_date = default_now()
instance.save()
put_explicit_now()
msg_queue.sweep()
Expand All @@ -84,11 +119,24 @@ def put_explicit_now():
model.QueuedMessage.select().count(),
0,
)
now = datetime.datetime.now()
now = default_now()
with freeze_time(now-relativedelta(months=6, seconds=1)):
put_explicit_now()
msg_queue.sweep()
self.assertEqual(
model.QueuedMessage.select().count(),
0,
)

@freeze_time()
def test_sweep_timeout(self):
timeout = datetime.timedelta(seconds=1)
msg_queue.put(self.node_id, self.msg)
msg_queue.put(self.node_id, self.msg, timeout)

msg_queue.sweep()
self.assertEqual(model.QueuedMessage.select().count(), 2)
with freeze_time(default_now() + datetime.timedelta(minutes=1)):
msg_queue.sweep()

self.assertEqual(model.QueuedMessage.select().count(), 1)
3 changes: 2 additions & 1 deletion tests/golem/task/server/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def exception_on_error(error):
if exception:
raise Exception(exception)

mock_queue.assert_called_once_with(node_id=self.key_id, msg=mock.ANY)
mock_queue.assert_called_once_with(
node_id=self.key_id, msg=mock.ANY, timeout=mock.ANY)

def test_start_handshake_nonce_errback(self, *_):
deferred = Deferred()
Expand Down