diff --git a/golem/client.py b/golem/client.py index e58ef94023..44f2eb499d 100644 --- a/golem/client.py +++ b/golem/client.py @@ -11,7 +11,6 @@ from typing import ( Any, Dict, - Hashable, Iterable, List, Optional, @@ -61,6 +60,7 @@ from golem.monitor.model.nodemetadatamodel import NodeMetadataModel from golem.monitor.monitor import SystemMonitor from golem.monitorconfig import MONITOR_CONFIG +from golem.network import broadcast from golem.network import nodeskeeper from golem.network.concent.client import ConcentClientService from golem.network.concent.filetransfers import ConcentFiletransferService @@ -263,6 +263,7 @@ def get_wamp_rpc_mapping(self): from golem.environments.minperformancemultiplier import \ MinPerformanceMultiplier from golem.network.concent import soft_switch as concent_soft_switch + from golem.rpc.api import broadcast_ as api_broadcast from golem.rpc.api import ethereum_ as api_ethereum from golem.task import rpc as task_rpc from golem.apps import rpc as apps_rpc @@ -282,6 +283,7 @@ def get_wamp_rpc_mapping(self): task_rpc_provider, app_rpc_provider, api_ethereum.ETSProvider(self.transaction_system), + api_broadcast, ) mapping = {} for rpc_provider in providers: @@ -1743,6 +1745,7 @@ def _run(self) -> None: jobs = ( nodeskeeper.sweep, msg_queue.sweep, + broadcast.sweep, lambda: logger.info( "Time marker. time(): %s now(): %s, utcnow(): %s, delta: %s", time.time(), diff --git a/golem/config/environments/mainnet.py b/golem/config/environments/mainnet.py index 0d2830f20c..cb920e39d8 100644 --- a/golem/config/environments/mainnet.py +++ b/golem/config/environments/mainnet.py @@ -54,6 +54,8 @@ def __init__(self): # P2P +BROADCAST_PUBKEY = b'\xab\xab;\xb0\x89\x10\r\xf8Hs\xd7\x91\xcc\x13\xdb\x0b9tw\x80\xd4t?\xdc\x9dS.\x9at\xe3X\xbcBK\x1c\xef\xdb3\xab}z\xad\xde"ZW\xa9T\xdeN\xb6\xc7P\x0e\xa9\x7fv\x1a\xec\xcbN\x07R\x10' # noqa pylint: disable=line-too-long + P2P_SEEDS = [ ('seeds.golem.network', 40102), ('0.seeds.golem.network', 40102), diff --git a/golem/config/environments/testnet.py b/golem/config/environments/testnet.py index e5f46d5683..674c82b445 100644 --- a/golem/config/environments/testnet.py +++ b/golem/config/environments/testnet.py @@ -55,6 +55,8 @@ def __init__(self): # P2P +BROADCAST_PUBKEY = b'\xbe\x0e\xb0@\xad\xad~\xd7\xe3\xca\x96*k\x7f\x0b*\x96++\xb0{\x95+n~\xfdF\xc8\x88\xff\x06\x93cr\xb3\xcb@\xc8Y\xd5n\x98|\xec\x90$\xf2E\xf9\xbbyh:\x99"\xaf\xa2-\xc9os:\xb6\x88' # noqa pylint: disable=line-too-long + P2P_SEEDS = [ ('94.23.57.58', 40102), ('94.23.57.58', 40104), diff --git a/golem/core/databuffer.py b/golem/core/databuffer.py index 0c7a24dc8a..9126ffb5f5 100644 --- a/golem/core/databuffer.py +++ b/golem/core/databuffer.py @@ -100,7 +100,7 @@ def read_len_prefixed_bytes(self): """ ret_bytes = None - if (self.data_size() > LONG_STANDARD_SIZE and + if (self.data_size() >= LONG_STANDARD_SIZE and self.data_size() >= (self.peek_ulong() + LONG_STANDARD_SIZE)): num_bytes = self.read_ulong() ret_bytes = self.read_bytes(num_bytes) diff --git a/golem/database/database.py b/golem/database/database.py index 4b6bff5a53..e1d4730205 100644 --- a/golem/database/database.py +++ b/golem/database/database.py @@ -61,7 +61,7 @@ def execute_sql(self, sql, params=None, require_commit=True): class Database: - SCHEMA_VERSION = 47 + SCHEMA_VERSION = 48 def __init__(self, # noqa pylint: disable=too-many-arguments db: peewee.Database, diff --git a/golem/database/schemas/048_create_broadcast.py b/golem/database/schemas/048_create_broadcast.py new file mode 100644 index 0000000000..c65fd7c349 --- /dev/null +++ b/golem/database/schemas/048_create_broadcast.py @@ -0,0 +1,25 @@ +# pylint: disable=no-member +# pylint: disable=unused-argument +import datetime + +import peewee as pw + +SCHEMA_VERSION = 48 + + +def migrate(migrator, database, fake=False, **kwargs): + @migrator.create_model # pylint: disable=unused-variable + class Broadcast(pw.Model): + timestamp = pw.IntegerField() + broadcast_type = pw.IntegerField() + signature = pw.BlobField() + data = pw.BlobField() + created_date = pw.DateTimeField(default=datetime.datetime.now) + modified_date = pw.DateTimeField(default=datetime.datetime.now) + + class Meta: + db_table = "broadcast" + + +def rollback(migrator, database, fake=False, **kwargs): + migrator.remove_model("broadcast") diff --git a/golem/model.py b/golem/model.py index f72536075c..52d96a185a 100644 --- a/golem/model.py +++ b/golem/model.py @@ -1,8 +1,10 @@ import datetime import enum +import hashlib import inspect import json import pickle +import struct import sys import time from typing import Any, Dict, Optional @@ -10,6 +12,7 @@ from eth_utils import decode_hex, encode_hex from ethereum.utils import denoms import golem_messages +from golem_messages import cryptography from golem_messages import datastructures as msg_dt from golem_messages import message from golem_messages.datastructures import p2p as dt_p2p, masking @@ -199,7 +202,7 @@ def python_value(self, value): class EnumField(EnumFieldBase, IntegerField): """ Database field that maps enum type to integer.""" - def __init__(self, enum_type, *args, **kwargs): + def __init__(self, *args, enum_type=None, **kwargs): super(EnumField, self).__init__(*args, **kwargs) self.enum_type = enum_type @@ -877,6 +880,105 @@ def __repr__(self): ) +class Broadcast(BaseModel): + class TYPE(enum.IntEnum): + Version = enum.auto() + HEADER_FORMAT = '!HQ' + HEADER_LENGTH = struct.calcsize(HEADER_FORMAT) + SIGNATURE_LENGTH = 65 + timestamp = IntegerField() + broadcast_type = EnumField(enum_type=TYPE) + signature = BlobField() + data = BlobField() + + def __repr__(self): + return ( + f"<{self.__class__.__qualname__}" + f" {self.timestamp} {self.broadcast_type.name}>" + ) + + @classmethod + def create_and_sign(cls, private_key, broadcast_type, data) -> 'Broadcast': + bc = cls() + bc.timestamp = int(time.time()) + bc.broadcast_type = broadcast_type + bc.data = data + bc.sign(private_key=private_key) + bc.save(force_insert=True) + return bc + + def process(self) -> bool: + if Broadcast.select().where( + Broadcast.broadcast_type == self.broadcast_type, + Broadcast.timestamp == self.timestamp, + ).exists(): + return False + if self.broadcast_type is self.TYPE.Version: + from golem.network.p2p.peersession import compare_version + compare_version(self.data.decode('utf-8', 'replace')) + self.save(force_insert=True) + return True + + def header_to_bytes(self) -> bytes: + return struct.pack( + self.HEADER_FORMAT, + self.broadcast_type, + self.timestamp, + ) + + def header_from_bytes(self, b: bytes) -> None: + try: + broadcast_type, self.timestamp = struct.unpack( + self.HEADER_FORMAT, + b, + ) + self.broadcast_type = self.TYPE(broadcast_type) # type: ignore + except (ValueError, struct.error): + from golem.network import broadcast + raise broadcast.BroadcastError('Invalid header') + + @classmethod + def from_bytes(cls, b: bytes) -> 'Broadcast': + # Remember to verify signature of this broadcast if it's been loaded + # from untrusted source + if len(b) < cls.HEADER_LENGTH + cls.SIGNATURE_LENGTH: + from golem.network import broadcast + raise broadcast.BroadcastError( + 'Invalid broadcast: too short' + f' ({len(b)} < {cls.HEADER_LENGTH + cls.SIGNATURE_LENGTH})', + ) + bc = cls() + bc.header_from_bytes(b[:cls.HEADER_LENGTH]) + bc.signature = b[ + cls.HEADER_LENGTH:cls.HEADER_LENGTH+cls.SIGNATURE_LENGTH + ] + bc.data = b[cls.HEADER_LENGTH+cls.SIGNATURE_LENGTH:] + return bc + + def to_bytes(self) -> bytes: + return self.header_to_bytes() + self.signature + self.data + + def get_hash(self) -> bytes: + sha = hashlib.sha1() + sha.update(self.header_to_bytes()) + sha.update(self.data) + return sha.digest() + + def sign(self, private_key: bytes) -> None: + assert self.signature is None + self.signature = cryptography.ecdsa_sign( + privkey=private_key, + msghash=self.get_hash(), + ) + + def verify_signature(self, public_key: bytes) -> None: + cryptography.ecdsa_verify( + pubkey=public_key, + signature=self.signature, + message=self.get_hash(), + ) + + def collect_db_models(module: str = __name__): return inspect.getmembers( sys.modules[module], diff --git a/golem/network/broadcast.py b/golem/network/broadcast.py new file mode 100644 index 0000000000..8d8b0a814e --- /dev/null +++ b/golem/network/broadcast.py @@ -0,0 +1,74 @@ +import logging +import typing + +import peewee + +from golem import decorators +from golem import model +from golem.config import active +from golem.core.databuffer import DataBuffer + + +logger = logging.getLogger(__name__) + + +class BroadcastError(Exception): + pass + + +def list_from_bytes(b: bytes) -> typing.List[model.Broadcast]: + db = DataBuffer() + db.append_bytes(b) + result = [] + for cnt, broadcast_binary in enumerate(db.get_len_prefixed_bytes()): + if cnt >= 10: + break + try: + b = model.Broadcast.from_bytes(broadcast_binary) + b.verify_signature(public_key=active.BROADCAST_PUBKEY) + result.append(b) + except BroadcastError as e: + logger.debug( + 'Invalid broadcast received: %s. b=%r', + e, + broadcast_binary, + ) + except Exception: # pylint: disable=broad-except + logger.debug( + 'Invalid broadcast received: %r', + broadcast_binary, + exc_info=True, + ) + return result + + +def list_to_bytes(l: typing.List[model.Broadcast]) -> bytes: + db = DataBuffer() + for broadcast in l: + assert isinstance(broadcast, model.Broadcast) + db.append_len_prefixed_bytes(broadcast.to_bytes()) + return db.read_all() + + +def prepare_handshake() -> typing.List[model.Broadcast]: + query = model.Broadcast.select().where( + model.Broadcast.broadcast_type == model.Broadcast.TYPE.Version, + ) + bl = [] + if query.exists(): + bl.append(query.order_by('-timestamp')[0]) + logger.debug('Prepared handshake: %s', bl) + return bl + + +@decorators.run_with_db() +def sweep() -> None: + max_timestamp = model.Broadcast.select( + peewee.fn.MAX(model.Broadcast.timestamp), + ).scalar() + count = model.Broadcast.delete().where( + model.Broadcast.broadcast_type == model.Broadcast.TYPE.Version, + model.Broadcast.timestamp < max_timestamp, + ).execute() + if count: + logger.info('Sweeped broadcasts. count=%d', count) diff --git a/golem/network/p2p/p2pservice.py b/golem/network/p2p/p2pservice.py index 53642490c8..bb137f7eac 100644 --- a/golem/network/p2p/p2pservice.py +++ b/golem/network/p2p/p2pservice.py @@ -72,7 +72,7 @@ def __init__( """ network = tcpnetwork.TCPNetwork( ProtocolFactory( - tcpnetwork.SafeProtocol, + tcpnetwork.BroadcastProtocol, self, SessionFactory(PeerSession) ), @@ -882,13 +882,17 @@ def _send_get_tasks(self): for p in list(self.peers.values()): p.send_get_tasks() - def __connection_established(self, session, conn_id: str): - peer_conn = session.conn.transport.getPeer() + def __connection_established( + self, + protocol: tcpnetwork.BroadcastProtocol, + conn_id: str, + ): + peer_conn = protocol.transport.getPeer() ip_address = peer_conn.host port = peer_conn.port - session.conn_id = conn_id - self._mark_connected(conn_id, session.address, session.port) + protocol.conn_id = conn_id + self._mark_connected(conn_id, ip_address, port) logger.debug("Connection to peer established. %s: %s, conn_id %s", ip_address, port, conn_id) diff --git a/golem/network/p2p/peersession.py b/golem/network/p2p/peersession.py index f017ac7bf0..17323669c4 100644 --- a/golem/network/p2p/peersession.py +++ b/golem/network/p2p/peersession.py @@ -273,11 +273,6 @@ def _react_to_hello(self, msg): logger.error("Received unexpected Hello message, ignoring") return - # Check if sender is a seed/bootstrap node - port = getattr(msg, 'port', None) - if (self.address, port) in self.p2p_service.seeds: - compare_version(getattr(msg, 'client_ver', None)) - if not self.conn.opened: return diff --git a/golem/network/transport/network.py b/golem/network/transport/network.py index a7e7f9a3ab..9f33f98af1 100644 --- a/golem/network/transport/network.py +++ b/golem/network/transport/network.py @@ -1,9 +1,76 @@ import abc +import logging +import typing + + +import transitions from twisted.internet.protocol import Factory, Protocol, connectionDone from .tcpnetwork_helpers import TCPConnectInfo, TCPListenInfo, TCPListeningInfo +logger = logging.getLogger(__name__) + + +class ExtendedMachine(transitions.Machine): + def add_transition_callback( # pylint: disable=too-many-arguments + self, + trigger: str, + source: str, + dest: str, + callback_trigger: str, # 'before', 'after' or 'prepare' + callback_func, + ): + for transition in self.get_transitions(trigger, source, dest): + transition.add_callback( + trigger=callback_trigger, + func=callback_func, + ) + + def copy_transitions( # pylint: disable=too-many-arguments + self, + from_trigger: str, + from_source: str, + from_dest: str, + to_trigger: str, + to_source: str, + to_dest: str, + ): + for transition in self.get_transitions( + from_trigger, + from_source, + from_dest, + ): + self.add_transition( + trigger=to_trigger, + source=to_source, + dest=to_dest, + before=transition.before, + after=transition.after, + prepare=transition.prepare, + ) + # conditions are ignored, implement if needed + + def move_transitions( # pylint: disable=too-many-arguments + self, + from_trigger: str, + from_source: str, + from_dest: str, + to_trigger: str, + to_source: str, + to_dest: str, + ): + self.copy_transitions( + from_trigger, + from_source, + from_dest, + to_trigger, + to_source, + to_dest, + ) + self.remove_transition(from_trigger, from_source, from_dest) + + class Network(abc.ABC): @abc.abstractmethod def connect(self, connect_info: TCPConnectInfo) -> None: @@ -19,81 +86,92 @@ def stop_listening(self, listening_info: TCPListeningInfo): class SessionFactory(object): + CONN_TYPE: typing.Optional[int] = None + def __init__(self, session_class): self.session_class = session_class - def get_session(self, conn): - return self.session_class(conn) - - -class IncomingSessionFactoryWrapper(object): - def __init__(self, session_factory): - self.session_factory = session_factory - - def get_session(self, conn): - session = self.session_factory.get_session(conn) - session.conn_type = Session.CONN_TYPE_SERVER - return session - + @classmethod + def from_factory(cls, factory: 'SessionFactory') -> 'SessionFactory': + return cls(session_class=factory.session_class) -class OutgoingSessionFactoryWrapper(object): - def __init__(self, session_factory): - self.session_factory = session_factory - - def get_session(self, conn): - session = self.session_factory.get_session(conn) - session.conn_type = Session.CONN_TYPE_CLIENT + def get_session(self, conn) -> 'Session': + session = self.session_class(conn) + session.conn_type = self.CONN_TYPE return session class ProtocolFactory(Factory): + SESSION_WRAPPER: typing.Optional[typing.Type['SessionFactory']] = None + def __init__(self, protocol_class, server=None, session_factory=None): self.protocol_class = protocol_class self.server = server + if self.SESSION_WRAPPER is not None: + session_factory = self.SESSION_WRAPPER.from_factory(session_factory) self.session_factory = session_factory - def buildProtocol(self, addr): - protocol = self.protocol_class(self.server) - protocol.set_session_factory(self.session_factory) - return protocol - - -class IncomingProtocolFactoryWrapper(Factory): - def __init__(self, protocol_factory): - self.protocol_factory = protocol_factory - self.session_factory = IncomingSessionFactoryWrapper( - protocol_factory.session_factory) - - def buildProtocol(self, addr): - protocol = self.protocol_factory.buildProtocol(addr) - protocol.set_session_factory(self.session_factory) - return protocol - - -class OutgoingProtocolFactoryWrapper(Factory): - def __init__(self, protocol_factory): - self.protocol_factory = protocol_factory - self.session_factory = OutgoingSessionFactoryWrapper( - protocol_factory.session_factory) + @classmethod + def from_factory(cls, factory: 'ProtocolFactory') -> 'ProtocolFactory': + return cls( + protocol_class=factory.protocol_class, + server=factory.server, + session_factory=factory.session_factory, + ) def buildProtocol(self, addr): - protocol = self.protocol_factory.buildProtocol(addr) - protocol.set_session_factory(self.session_factory) - return protocol + return self.protocol_class(self.session_factory, server=self.server) class SessionProtocol(Protocol): - def __init__(self): + def __init__(self, session_factory, **_kwargs): """Connection-oriented basic protocol for twisted""" - self.session_factory = None - self.session = None - - def set_session_factory(self, session_factory): - """ :param SessionFactory session_factory: """ self.session_factory = session_factory + self.session: typing.Optional[Session] = None + self.machine = ExtendedMachine( + self, + states=[ + 'initial', + 'connected', + 'disconnected', + ], + initial='initial', + auto_transitions=False, + ) + self.machine.add_transition( + 'connectionMadeTransition', + 'initial', + 'connected', + after=self.create_session, + ) + self.machine.add_transition( + 'connectionLostTransition', + '*', + 'disconnected', + ) + + def after_disconnection(reason): # pylint: disable=unused-argument + self.session.dropped() + delattr(self, 'session') + self.machine.add_transition_callback( + 'connectionLostTransition', + 'connected', + 'disconnected', + 'after', + after_disconnection, + ) - # Protocol function def connectionMade(self): + super().connectionMade() + # map twisted Protocol event into transition + self.connectionMadeTransition() # pylint: disable=no-member + + def connectionLost(self, reason=connectionDone): + super().connectionLost(reason=reason) + # map twisted Protocol event into transition + self.connectionLostTransition(reason=reason) # noqa pylint: disable=no-member + + def create_session(self) -> None: """Called when new connection is successfully opened""" # If the underlying transport is TCP, enable TCP keepalive. @@ -104,12 +182,8 @@ def connectionMade(self): except AttributeError: pass - Protocol.connectionMade(self) self.session = self.session_factory.get_session(self) - def connectionLost(self, reason=connectionDone): - del self.session - class Session(object, metaclass=abc.ABCMeta): CONN_TYPE_CLIENT = 1 @@ -129,3 +203,19 @@ def interpret(self, msg): @abc.abstractmethod def disconnect(self, reason): raise NotImplementedError + + +class IncomingSessionFactory(SessionFactory): + CONN_TYPE = Session.CONN_TYPE_SERVER + + +class OutgoingSessionFactory(SessionFactory): + CONN_TYPE = Session.CONN_TYPE_CLIENT + + +class IncomingProtocolFactory(ProtocolFactory): + SESSION_WRAPPER = IncomingSessionFactory + + +class OutgoingProtocolFactory(ProtocolFactory): + SESSION_WRAPPER = OutgoingSessionFactory diff --git a/golem/network/transport/tcpnetwork.py b/golem/network/transport/tcpnetwork.py index 74aa674f57..9e3b9ed3bf 100644 --- a/golem/network/transport/tcpnetwork.py +++ b/golem/network/transport/tcpnetwork.py @@ -1,6 +1,7 @@ import logging import struct import time +import typing import golem_messages from golem_messages import message @@ -8,13 +9,17 @@ from twisted.internet.endpoints import TCP4ServerEndpoint, \ TCP4ClientEndpoint, TCP6ServerEndpoint, TCP6ClientEndpoint, \ HostnameEndpoint -from twisted.internet.protocol import connectionDone from golem.core.databuffer import DataBuffer from golem.core.hostaddress import get_host_addresses +from golem.network import broadcast from golem.network.transport.limiter import CallRateLimiter -from .network import Network, SessionProtocol, IncomingProtocolFactoryWrapper, \ - OutgoingProtocolFactoryWrapper +from .network import ( + IncomingProtocolFactory, + Network, + OutgoingProtocolFactory, + SessionProtocol, +) from .spamprotector import SpamProtector # Import helpers to this namespace @@ -46,9 +51,9 @@ def __init__(self, protocol_factory, use_ipv6=False, timeout=5, """ from twisted.internet import reactor self.reactor = reactor - self.incoming_protocol_factory = IncomingProtocolFactoryWrapper( + self.incoming_protocol_factory = IncomingProtocolFactory.from_factory( protocol_factory) - self.outgoing_protocol_factory = OutgoingProtocolFactoryWrapper( + self.outgoing_protocol_factory = OutgoingProtocolFactory.from_factory( protocol_factory) self.use_ipv6 = use_ipv6 self.timeout = timeout @@ -165,13 +170,13 @@ def __try_to_connect_to_address(self, connect_info: TCPConnectInfo): connect_info) @staticmethod - def __connection_established(conn, established_callback, + def __connection_established(protocol, established_callback, connect_info: TCPConnectInfo): - pp = conn.transport.getPeer() + pp = protocol.transport.getPeer() logger.debug("Connection established %r %r", pp.host, pp.port) TCPNetwork.__call_established_callback( established_callback, - conn.session, + protocol, connect_info, ) @@ -182,11 +187,11 @@ def __connection_failure(err_desc, failure_callback, TCPNetwork.__call_failure_callback(failure_callback, connect_info) @staticmethod - def __connection_to_address_established(conn, + def __connection_to_address_established(protocol, connect_info: TCPConnectInfo): TCPNetwork.__call_established_callback( connect_info.established_callback, - conn, + protocol, ) def __connection_to_address_failure(self, connect_info: TCPConnectInfo): @@ -243,7 +248,17 @@ def __call_established_callback(established_callback, result, *args, **kwargs): if established_callback is None: return - established_callback(result, *args, **kwargs) + try: + established_callback(result, *args, **kwargs) + except Exception: # pylint: disable=broad-except + logger.error( + "Problem calling established callback: %s(*%s, **%s)", + established_callback, + args, + kwargs, + exc_info=True, + ) + raise @staticmethod def __stop_listening_success(result, callback): @@ -269,9 +284,8 @@ class BasicProtocol(SessionProtocol): serialization """ - def __init__(self): - super().__init__() - self.opened = False + def __init__(self, session_factory, **_kwargs): + super().__init__(session_factory) self.db = DataBuffer() self.spam_protector = SpamProtector() @@ -307,12 +321,11 @@ def close(self): """ self.transport.loseConnection() - # Protocol functions - def connectionMade(self): - """Called when new connection is successfully opened""" - SessionProtocol.connectionMade(self) - self.opened = True + @property + def opened(self): + return self.is_connected() # pylint: disable=no-member + # Protocol functions def dataReceived(self, data): """Called when additional chunk of data is received from another peer""" @@ -325,16 +338,9 @@ def dataReceived(self, data): self._interpret(data) - def connectionLost(self, reason=connectionDone): - """Called when connection is lost (for whatever reason)""" - self.opened = False - if self.session: - self.session.dropped() - - SessionProtocol.connectionLost(self, reason) - # Protected functions - def _prepare_msg_to_send(self, msg): + @classmethod + def _prepare_msg_to_send(cls, msg): ser_msg = golem_messages.dump(msg, None, None) db = DataBuffer() @@ -351,7 +357,8 @@ def _interpret(self, data): for m in mess: self.session.interpret(m) - def _load_message(self, data): + @classmethod + def _load_message(cls, data): msg = golem_messages.load(data, None, None) logger.debug( 'BasicProtocol._load_message(): received %r', @@ -420,27 +427,27 @@ class ServerProtocol(BasicProtocol): """ Basic protocol connected to server instance """ - def __init__(self, server): + def __init__(self, session_factory, server, **_kwargs): """ :param Server server: server instance :return None: """ - BasicProtocol.__init__(self) + BasicProtocol.__init__(self, session_factory) self.server = server + self.machine.add_transition_callback( + 'connectionMadeTransition', 'initial', 'connected', + 'after', + lambda: self.server.new_connection(self.session), + ) # Protocol functions - def connectionMade(self): - """Called when new connection is successfully opened""" - BasicProtocol.connectionMade(self) - self.server.new_connection(self.session) - def _can_receive(self) -> bool: if not self.opened: logger.warning("Protocol is closed") return False if not self.session and self.server: - self.opened = False + self.connectionLostTransition() # pylint: disable=no-member logger.warning('Peer for connection is None') return False @@ -479,3 +486,94 @@ def _load_message(self, data): msg, ) return msg + + +class BroadcastProtocol(SafeProtocol): + """Send and expect broadcast message before any other communication""" + + def __init__(self, session_factory, server, **_kwargs): + super().__init__(session_factory, server) + self.conn_id: typing.Optional[str] = None + self.machine.add_state('handshaking') + self.machine.add_transition( + 'connectionMadeTransition', + 'initial', + 'handshaking', + after=self.sendHandshake, + ) + self.machine.move_transitions( + from_trigger='connectionMadeTransition', + from_source='initial', + from_dest='connected', + to_trigger='handshakeFinished', + to_source='handshaking', + to_dest='connected', + ) + self.machine.copy_transitions( + from_trigger='connectionLostTransition', + from_source='connected', + from_dest='disconnected', + to_trigger='connectionLostTransition', + to_source='handshaking', + to_dest='disconnected', + ) + + def create_session(self): + super().create_session() + self.session.conn_id = self.conn_id + + def sendHandshake(self) -> bool: + handshake_bytes = broadcast.list_to_bytes(broadcast.prepare_handshake()) + db = DataBuffer() + db.append_len_prefixed_bytes(handshake_bytes) + + self.transport.getHandle() + self.transport.write(db.read_all()) + return True + + def dataReceived(self, data: bytes) -> None: + if self.is_connected(): # pylint: disable=no-member + return super().dataReceived(data) + if self.is_handshaking(): # pylint: disable=no-member + try: + return self.dataReceivedHandshake(data) + except broadcast.BroadcastError: + logger.debug( + 'Invalid broadcast received. peer=%s, data=%s', + self.transport.getPeer(), + data, + ) + return None + logger.debug( + '%(module_name)s.%(class_name)s.dataReceived(%(data)r)' + ' Protocol not ready.' + ' Current state: %(machine)s', + { + 'class_name': self.__class__.__qualname__, + 'module_name': __name__, + 'data': data, + 'machine': self.state, # pylint: disable=no-member + }, + ) + return None + + def dataReceivedHandshake(self, data: bytes) -> None: + logger.debug('handshake data received %sb', len(data)) + self.db.append_bytes(data) + b = self.db.read_len_prefixed_bytes() + if b is None: + return None + broadcasts_l = broadcast.list_from_bytes(b) + for bc in broadcasts_l: + if not bc.process(): + logger.debug( + 'Broadcast rejected: %s, peer: %s', + bc, + self.transport.getPeer(), + ) + self.handshakeFinished() # pylint: disable=no-member + logger.debug( + "Sucesfuly finished handshake with %s", + self.transport.getPeer(), + ) + return None diff --git a/golem/network/transport/tcpnetwork_helpers.py b/golem/network/transport/tcpnetwork_helpers.py index b71bd1c1fb..f6fe82e5f9 100644 --- a/golem/network/transport/tcpnetwork_helpers.py +++ b/golem/network/transport/tcpnetwork_helpers.py @@ -226,9 +226,12 @@ def __init__(self, if final_failure_callback else None) def __str__(self): + def get_func(cbk): + return cbk.func if cbk is not None else None + return ("TCP connection information: addresses {}, " "callback {}, errback {}, final_errback {}").format( self.socket_addresses, - self.established_callback.func, - self.failure_callback.func, - self.final_failure_callback.func) + get_func(self.established_callback), + get_func(self.failure_callback), + get_func(self.final_failure_callback)) diff --git a/golem/rpc/api/broadcast_.py b/golem/rpc/api/broadcast_.py new file mode 100644 index 0000000000..82f944c6aa --- /dev/null +++ b/golem/rpc/api/broadcast_.py @@ -0,0 +1,62 @@ +from golem import model +from golem.config import active +from golem.core import common +from golem.rpc import utils as rpc_utils + + +@rpc_utils.expose('broadcast.hash') +def hash_( + timestamp: int, + broadcast_type: int, + data_hex: str, +) -> str: + """Generate hash of a broadcast that should be signed by client + before pushing + """ + type_ = model.Broadcast.TYPE(int(broadcast_type)) + data = bytes.fromhex(data_hex) + bc = model.Broadcast( + broadcast_type=type_, + timestamp=int(timestamp), + data=data, + ) + return bc.get_hash().hex() + + +@rpc_utils.expose('broadcast.push') +def push( + timestamp: int, + broadcast_type: int, + data_hex: str, + signature_hex: str, +): + """Push signed broadcast into the p2p network + """ + data = bytes.fromhex(data_hex) + signature = bytes.fromhex(signature_hex) + bc = model.Broadcast( + broadcast_type=model.Broadcast.TYPE(int(broadcast_type)), + timestamp=int(timestamp), + data=data, + signature=signature, + ) + bc.verify_signature(public_key=active.BROADCAST_PUBKEY) + if not bc.process(): + raise RuntimeError("Broadcast rejected") + + +@rpc_utils.expose('broadcast.list') +def list_(): + """Return all known broadcasts from local DB + """ + return [ + { + 'timestamp': bc.timestamp, + 'broadcast_type': bc.broadcast_type.value, + 'broadcast_type_name': bc.broadcast_type.name, + 'data_hex': bc.data.hex(), + 'created_date': common.datetime_to_timestamp_utc(bc.created_date), + } + for bc + in model.Broadcast.select().order_by('created_date') + ] diff --git a/golem/task/server/queue_.py b/golem/task/server/queue_.py index 0148a569ad..a8b93e2e95 100644 --- a/golem/task/server/queue_.py +++ b/golem/task/server/queue_.py @@ -11,6 +11,7 @@ if typing.TYPE_CHECKING: # pylint: disable=unused-import + from golem.network.transport import tcpnetwork from golem.task import taskkeeper from golem.task.tasksession import TaskSession @@ -110,10 +111,13 @@ def sweep_sessions(self): def msg_queue_connection_established( self, - session: 'TaskSession', + protocol: 'tcpnetwork.SafeProtocol', conn_id, node_id, ): + session = protocol.session + if typing.TYPE_CHECKING: + assert isinstance(session, TaskSession) try: if self.sessions[node_id] is not None: # There is a session already established diff --git a/requirements.txt b/requirements.txt index 0f98495e2e..05b614ad96 100644 --- a/requirements.txt +++ b/requirements.txt @@ -157,6 +157,7 @@ tabulate==0.8.2 token-bucket==0.2.0 toml==0.10.0 toolz==0.9.0 +transitions==0.7.2 treq==17.8.0 Twisted==19.7.0 txaio==18.8.1 diff --git a/requirements_to-freeze.txt b/requirements_to-freeze.txt index ce654f89ef..beef649c79 100644 --- a/requirements_to-freeze.txt +++ b/requirements_to-freeze.txt @@ -59,6 +59,7 @@ setuptools>=36.0.1,<39.0.0 six==1.12.0 tabulate token_bucket==0.2.0 +transitions==0.7.2 Twisted==19.7.0 txaio==18.8.1 ujson==1.35 diff --git a/scripts/broadcast-helper.py b/scripts/broadcast-helper.py new file mode 100755 index 0000000000..99458e73a9 --- /dev/null +++ b/scripts/broadcast-helper.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 + +from getpass import getpass +import os.path +import time + +import appdirs +from golem_messages import cryptography + +import golem +from golem import model + + +def ask(prompt, default): + answer = input(f"{prompt} [{default}]: ") + return answer or default + + +def main(): + datadir = ask( + 'datadir', + os.path.join(appdirs.user_data_dir('golem'), 'default'), + ) + port = ask('RPC port', '61000') + cli_invocation = f"golemcli -d {datadir} -p {port} debug rpc " + timestamp = int(time.time()) + broadcast_type = model.Broadcast.TYPE( + ask('broadcast type', model.Broadcast.TYPE.Version.value), + ) + print('selected', broadcast_type) + data = ask('data', golem.__version__).encode('ascii') + print( + cli_invocation + + f"broadcast.hash {timestamp} {broadcast_type.value} {data.hex()}", + ) + hash_ = bytes.fromhex(input('hash hex: ')) + private_key = bytes.fromhex(getpass('Private key (hex): ')) + signature = cryptography.ecdsa_sign(private_key, hash_) + print( + cli_invocation + + f"broadcast.push {timestamp} {broadcast_type.value} {data.hex()}" + f" {signature.hex()}", + ) + print( + cli_invocation + + "broadcast.list", + ) + + +if __name__ == '__main__': + main() diff --git a/tests/golem/network/p2p/test_p2pservice.py b/tests/golem/network/p2p/test_p2pservice.py index 2c7bd43875..9fd43e8079 100644 --- a/tests/golem/network/p2p/test_p2pservice.py +++ b/tests/golem/network/p2p/test_p2pservice.py @@ -321,11 +321,11 @@ def test_seeds_round_robin(self, m_connect): self.assertGreater(len(self.service.seeds), 0) self.service.connect_to_known_hosts = True self.service.connect_to_seeds() - self.assertEquals(m_connect.call_count, 1) + self.assertEqual(m_connect.call_count, 1) m_connect.reset_mock() m_connect.side_effect = RuntimeError('ConnectionProblem') self.service.connect_to_seeds() - self.assertEquals(m_connect.call_count, len(self.service.seeds)) + self.assertEqual(m_connect.call_count, len(self.service.seeds)) def test_want_to_start_task_session(self): self.service.task_server = mock.MagicMock() @@ -336,7 +336,7 @@ def test_want_to_start_task_session(self): self.service.task_server.task_connections_helper \ .is_new_conn_request = mock.Mock(side_effect=lambda *_: True) - def true_method(*args) -> bool: + def true_method(*_args) -> bool: return True def gen_uuid(): @@ -473,19 +473,21 @@ def test_connect_success(self, connection_established, createSocket): socket.fileno = mock.Mock(return_value=0) socket.getsockopt = mock.Mock(return_value=None) socket.connect_ex = mock.Mock(return_value=EISCONN) + socket.recv = mock.Mock(return_value=b'') addr = SocketAddress('127.0.0.1', 40102) self.service.connect(addr) time.sleep(0.1) - assert connection_established.called - assert connection_established.call_args[0][0] is not None - assert connection_established.call_args[1]['conn_id'] is not None + connection_established.assert_called_once_with( + mock.ANY, + conn_id=mock.ANY, + ) @mock.patch('twisted.internet.tcp.BaseClient.createInternetSocket', side_effect=Exception('something has failed')) @mock.patch('golem.network.p2p.p2pservice.' 'P2PService._P2PService__connection_failure') - def test_connect_failure(self, connection_failure, createSocket): + def test_connect_failure(self, connection_failure, _createSocket): self.service.resume() addr = SocketAddress('127.0.0.1', 40102) self.service.connect(addr) diff --git a/tests/golem/network/p2p/test_peersession.py b/tests/golem/network/p2p/test_peersession.py index 84bbe430ef..b181ac0c9d 100644 --- a/tests/golem/network/p2p/test_peersession.py +++ b/tests/golem/network/p2p/test_peersession.py @@ -1,18 +1,13 @@ # pylint: disable=protected-access,no-member import copy -import ipaddress import random -import sys import uuid from unittest import TestCase from unittest.mock import patch, Mock, MagicMock, ANY -import semantic_version -from golem_messages import factories as msg_factories from golem_messages import message from golem_messages.factories.datastructures import p2p as dt_p2p_factory -from pydispatch import dispatcher import golem from golem import clientconfigdescriptor @@ -240,86 +235,6 @@ def test_handshake_client_randval(self, send_mock): message.base.RandVal(rand_val=-1)) self.assertFalse(self.peer_session.verified) - def test_react_to_hello_new_version(self): - listener = MagicMock() - dispatcher.connect(listener, signal='golem.p2p') - self.peer_session.p2p_service.seeds = { - (host, random.randint(0, 65535)) - for host in - ipaddress.ip_network('192.0.2.0/29').hosts() - } - - peer_info = MagicMock() - peer_info.key = ( - 'What is human warfare but just this;' - 'an effort to make the laws of God and nature' - 'take sides with one party.' - ) - msg_kwargs = { - 'port': random.randint(0, 65535), - 'client_ver': None, - 'node_info': peer_info, - 'proto_id': random.randint(0, sys.maxsize), - 'metadata': None, - 'solve_challenge': None, - 'challenge': None, - 'difficulty': None, - } - - # Test not seed - msg = message.base.Hello(**msg_kwargs) - self.peer_session._react_to_hello(msg) - self.assertEqual(listener.call_count, 0) - listener.reset_mock() - - # Choose one seed - chosen_seed = random.choice(tuple(self.peer_session.p2p_service.seeds)) - msg_kwargs['port'] = chosen_seed[1] - self.peer_session.address = chosen_seed[0] - - # Test with seed, default version (0) - msg = message.base.Hello(**msg_kwargs) - self.peer_session._react_to_hello(msg) - self.assertEqual(listener.call_count, 0) - listener.reset_mock() - - # Test with seed, newer version - version = semantic_version.Version(golem.__version__).next_patch() - msg_kwargs['client_ver'] = str(version) - msg = message.base.Hello(**msg_kwargs) - self.peer_session._react_to_hello(msg) - listener.assert_called_once_with( - signal='golem.p2p', - event='new_version', - version=version, - sender=ANY, - ) - listener.reset_mock() - - def test_react_to_hello_new_version_partial(self): - "Sometimes we'll get partial version from bootstrap node" - listener = MagicMock() - dispatcher.connect(listener, signal='golem.p2p') - self.peer_session.p2p_service.seeds = { - (host, random.randint(0, 65535)) - for host in - ipaddress.ip_network('192.0.2.0/29').hosts() - } - version = semantic_version.Version(golem.__version__) - chosen_seed = random.choice(tuple(self.peer_session.p2p_service.seeds)) - self.peer_session.address = chosen_seed[0] - msg = msg_factories.base.HelloFactory( - client_ver=f"{version.major}.{version.next_minor().minor}", - port=chosen_seed[1], - ) - self.peer_session._react_to_hello(msg) - listener.assert_called_once_with( - signal='golem.p2p', - event='new_version', - version=version.next_minor(), # Full version here - sender=ANY, - ) - def test_disconnect(self): conn = MagicMock() peer_session = PeerSession(conn) diff --git a/tests/golem/network/test_broadcast.py b/tests/golem/network/test_broadcast.py new file mode 100644 index 0000000000..62e38d7d2d --- /dev/null +++ b/tests/golem/network/test_broadcast.py @@ -0,0 +1,41 @@ +from freezegun import freeze_time +import peewee + +from golem import model +from golem import testutils +from golem.network import broadcast + + +class SweepTestCase(testutils.DatabaseFixture): + def setUp(self): + super().setUp() + self.privkey = b"/#\x99s\xff\x97Y\xf1\xa1\x03\xd4N4\x14F\x94\xbc\x87\xacr\\\x9f\xf6\x96'\xa5\x18\xeb\x19\xc04-" # noqa pylint: disable=line-too-long + + @classmethod + def test_basic(cls): + broadcast.sweep() + + def test_single(self): + model.Broadcast.create_and_sign(self.privkey, 1, b'1.3.3.7') + broadcast.sweep() + self.assertEqual( + model.Broadcast.select(peewee.fn.Count()).scalar(), + 1, + ) + + def test_two(self): + with freeze_time("2018-01-01 00:00:00") as frozen_time: + # override a bug in freezegun that passes + # frozen_time as first argument even to methods (before self) + self.frozen_two(frozen_time) + + def frozen_two(self, frozen_time): + model.Broadcast.create_and_sign(self.privkey, 1, b'1.3.3.7') + # bug in freezegun puts frozen_time as first argument event in methods + frozen_time.tick() # pylint: disable=no-member + model.Broadcast.create_and_sign(self.privkey, 1, b'3.1.3.3.7') + broadcast.sweep() + self.assertEqual( + model.Broadcast.select(peewee.fn.Count()).scalar(), + 1, + ) diff --git a/tests/golem/network/transport/test_network.py b/tests/golem/network/transport/test_network.py index 79df8fd144..8c7fcd880f 100644 --- a/tests/golem/network/transport/test_network.py +++ b/tests/golem/network/transport/test_network.py @@ -1,15 +1,14 @@ from contextlib import contextmanager import logging -import os import time import unittest from golem_messages import message import golem_messages.cryptography +import transitions from golem.network.transport.network import ProtocolFactory, SessionFactory, \ SessionProtocol -from golem.network.transport import session from golem.network.transport.tcpnetwork import TCPNetwork, TCPListenInfo, \ TCPListeningInfo, TCPConnectInfo, \ SocketAddress, BasicProtocol, ServerProtocol, SafeProtocol @@ -278,8 +277,17 @@ def write(self, msg): class TestProtocols(unittest.TestCase): + @classmethod + def get_protocols(cls): + session_factory = SessionFactory(ASession) + return [ + BasicProtocol(session_factory, ), + ServerProtocol(session_factory, Server()), + SafeProtocol(session_factory, Server()), + ] + def test_init(self): - prt = [BasicProtocol(), ServerProtocol(Server()), SafeProtocol(Server())] + prt = self.get_protocols() for p in prt: from twisted.internet.protocol import Protocol self.assertTrue(isinstance(p, Protocol)) @@ -289,19 +297,15 @@ def test_init(self): self.assertIsNotNone(p.server) def test_close(self): - prt = [BasicProtocol(), ServerProtocol(Server()), SafeProtocol(Server())] - for p in prt: + for p in self.get_protocols(): p.transport = Transport() self.assertFalse(p.transport.lose_connection_called) p.close() self.assertTrue(p.transport.lose_connection_called) def test_connection_made(self): - prt = [BasicProtocol(), ServerProtocol(Server()), SafeProtocol(Server())] - for p in prt: + for p in self.get_protocols(): p.transport = Transport() - session_factory = SessionFactory(ASession) - p.set_session_factory(session_factory) self.assertFalse(p.opened) p.connectionMade() self.assertTrue(p.opened) @@ -311,29 +315,22 @@ def test_connection_made(self): self.assertNotIn('session', p.__dict__) def test_connection_lost(self): - prt = [BasicProtocol(), ServerProtocol(Server()), SafeProtocol(Server())] - for p in prt: + for p in self.get_protocols(): p.transport = Transport() - session_factory = SessionFactory(ASession) - p.set_session_factory(session_factory) self.assertIsNone(p.session) p.connectionLost() self.assertFalse(p.opened) - p.connectionMade() - self.assertTrue(p.opened) - self.assertIsNotNone(p.session) - self.assertFalse(p.session.dropped_called) - p.connectionLost() - self.assertFalse(p.opened) - self.assertNotIn('session', p.__dict__) + with self.assertRaises(transitions.MachineError): + # Can't trigger event connectionMadeTransition + # from state disconnected! + p.connectionMade() class TestBasicProtocol(unittest.TestCase): def test_send_and_receive_message(self): - p = BasicProtocol() - p.transport = Transport() session_factory = SessionFactory(ASession) - p.set_session_factory(session_factory) + p = BasicProtocol(session_factory) + p.transport = Transport() self.assertFalse(p.send_message("123")) msg = message.base.Hello() self.assertFalse(p.send_message(msg)) @@ -358,9 +355,8 @@ def test_send_and_receive_message(self): class TestServerProtocol(unittest.TestCase): def test_connection_made(self): - p = ServerProtocol(Server()) session_factory = SessionFactory(ASession) - p.set_session_factory(session_factory) + p = ServerProtocol(session_factory, Server()) p.connectionMade() self.assertEqual(len(p.server.sessions), 1) p.connectionLost() @@ -369,10 +365,9 @@ def test_connection_made(self): class TestSaferProtocol(unittest.TestCase): def test_send_and_receive_message(self): - p = SafeProtocol(Server()) - p.transport = Transport() session_factory = SessionFactory(ASession) - p.set_session_factory(session_factory) + p = SafeProtocol(session_factory, Server()) + p.transport = Transport() self.assertFalse(p.send_message("123")) msg = message.base.Hello() self.assertIsNone(msg.sig) diff --git a/tests/golem/network/transport/test_tcpnetwork.py b/tests/golem/network/transport/test_tcpnetwork.py index 20dd5c3de7..0bbcd5ef7f 100644 --- a/tests/golem/network/transport/test_tcpnetwork.py +++ b/tests/golem/network/transport/test_tcpnetwork.py @@ -3,9 +3,9 @@ import unittest from unittest import mock -import golem_messages import semantic_version from freezegun import freeze_time +import golem_messages from golem_messages import exceptions as msg_exceptions from golem_messages import message from golem_messages import factories as msg_factories @@ -32,10 +32,13 @@ class TestConformance(unittest.TestCase, testutils.PEP8MixIn): class TestBasicProtocol(LogTestCase): def setUp(self): - self.protocol = tcpnetwork.BasicProtocol() - self.protocol.session = mock.MagicMock() - self.protocol.session.my_private_key = None - self.protocol.session.theirs_public_key = None + session_mock = mock.MagicMock() + session_mock.my_private_key = None + session_mock.theirs_private_key = None + self.protocol = tcpnetwork.BasicProtocol( + mock.MagicMock(return_value=session_mock), + ) + self.protocol.transport = mock.MagicMock() def test_init(self): @@ -44,11 +47,13 @@ def test_init(self): @mock.patch('golem_messages.load') def test_dataReceived(self, load_mock): data = b"abc" - self.assertIsNone(self.protocol.dataReceived(data)) - self.protocol.opened = True - self.assertIsNone(self.protocol.dataReceived(data)) + # can_receive() returns False + self.protocol.dataReceived(data) + + self.protocol.connectionMade() + self.protocol.dataReceived(data) self.protocol.db.clear_buffer() - self.assertEqual(load_mock.call_count, 0) + load_mock.assert_not_called() m = message.base.Disconnect(reason=None) data = m.serialize() @@ -61,8 +66,8 @@ def test_dataReceived(self, load_mock): 'golem.network.transport.tcpnetwork.BasicProtocol._load_message' ) def test_dataReceived_long(self, load_mock): + self.protocol.connectionMade() data = bytes([0xff] * (MAX_MESSAGE_SIZE + 1)) - self.protocol.opened = True self.assertIsNone(self.protocol.dataReceived(data)) self.assertEqual(load_mock.call_count, 0) @@ -101,11 +106,14 @@ def test_golem_messages_failed(self, check_mock, close_mock, send_mock): class SafeProtocolTestCase(unittest.TestCase): def setUp(self): - self.protocol = SafeProtocol(MagicMock()) - self.protocol.opened = True - self.protocol.session = mock.MagicMock() - self.protocol.session.my_private_key = None - self.protocol.session.theirs_public_key = None + session_mock = mock.MagicMock() + session_mock.my_private_key = None + session_mock.theirs_private_key = None + self.protocol = SafeProtocol( + MagicMock(return_value=session_mock), + MagicMock(), + ) + self.protocol.connectionMade() @mock.patch('golem_messages.load') def test_drop_set_task(self, load_mock): @@ -136,7 +144,7 @@ def test_zone_index(self): address = "fe80::3%eth0" port = 1111 sa = SocketAddress(address, port) - assert sa.address == base_address + self.assertEqual(sa.address, base_address) assert sa.port == port address = "fe80::3%1" @@ -152,7 +160,7 @@ def test_zone_index(self): assert sa.address == base_address def test_is_proper_address(self): - assert SocketAddress.is_proper_address("127.0.0.1", 1020) + self.assertTrue(SocketAddress.is_proper_address("127.0.0.1", 1020)) assert not SocketAddress.is_proper_address("127.0.0.1", 0) assert not SocketAddress.is_proper_address("127.0.0.1", "ABC") assert not SocketAddress.is_proper_address("AB?*@()F*)A", 1020) diff --git a/tests/golem/rpc/api/test_broadcast.py b/tests/golem/rpc/api/test_broadcast.py new file mode 100644 index 0000000000..2c4a999617 --- /dev/null +++ b/tests/golem/rpc/api/test_broadcast.py @@ -0,0 +1,116 @@ +from unittest import mock + +from freezegun import freeze_time +import golem_messages.exceptions + +from golem import model +from golem import testutils +from golem.config import active +from golem.rpc.api import broadcast_ as api_broadcast + +PRIVATE_KEY = b"\x91M7\x06\x85\xd1\x15\xc7\x14\t\xe9\xca+\xef\xce\x15\xdf\xc5\xb6\x93]\xdc\xd0p\x0f\x18'\x92=3\n/" # noqa pylint: disable=line-too-long +PUBLIC_KEY = b'\xb7\xdap\xa8\xbb\xb49\xe8\xf1\xcd\xf7IL\xe1c)J\x88L\xca\xf9\xf1\x17\x02><\xad^]L\xb6\x06U\xae\xc6\x97\xc8Y\xfd\xeb\x98\x80\xef\x94\xe3p^\xe0\xa2\xddD\xeb\xa7\xd6\x8c\xab\xcd\x90\xe7\x97+H\xd0\x0f' # noqa pylint: disable=line-too-long + + +class BroadcastTestBase(testutils.DatabaseFixture): + def setUp(self): + super().setUp() + self.timestamp = 1582813813 + self.broadcast_type = 1 + self.data_hex = '302e3233' + self.hash_ = '20cd626884c83455ab59fbfbfe2944fa6e187c20' + self.signature_hex = '7cf206f88696700f1a6f87c4a99a4bf11e8526a860f2a9d32345a3c1f9a95d985e1878ef60495e8deca4032d5622ffa02a3f059248084d07aee4dd4effead64500' # noqa pylint: disable=line-too-long + self.query = model.Broadcast.select().where( + model.Broadcast.timestamp == self.timestamp, + model.Broadcast.broadcast_type == self.broadcast_type, + model.Broadcast.data == b'0.23', + model.Broadcast.signature == b'|\xf2\x06\xf8\x86\x96p\x0f\x1ao\x87\xc4\xa9\x9aK\xf1\x1e\x85&\xa8`\xf2\xa9\xd3#E\xa3\xc1\xf9\xa9]\x98^\x18x\xef`I^\x8d\xec\xa4\x03-V"\xff\xa0*?\x05\x92H\x08M\x07\xae\xe4\xddN\xff\xea\xd6E\x00', # noqa pylint: disable=line-too-long + ) + + +class HashTest(BroadcastTestBase): + def test_basic(self): + result = api_broadcast.hash_( + timestamp=self.timestamp, + broadcast_type=self.broadcast_type, + data_hex=self.data_hex, + ) + self.assertIsInstance(result, str) + self.assertEqual(result, self.hash_) + + def test_string_arguments(self): + # Useful when using with `golemcli debug rpc` + result = api_broadcast.hash_( + timestamp=str(self.timestamp), + broadcast_type=str(self.broadcast_type), + data_hex=self.data_hex, + ) + self.assertIsInstance(result, str) + self.assertEqual(result, self.hash_) + + +@mock.patch.object(active, 'BROADCAST_PUBKEY', PUBLIC_KEY) +class PushTest(BroadcastTestBase): + def test_basic(self): + api_broadcast.push( + timestamp=self.timestamp, + broadcast_type=self.broadcast_type, + data_hex=self.data_hex, + signature_hex=self.signature_hex, + ) + self.assertTrue(self.query.exists()) + + def test_string_arguments(self): + # Useful when using with `golemcli debug rpc` + api_broadcast.push( + timestamp=str(self.timestamp), + broadcast_type=str(self.broadcast_type), + data_hex=self.data_hex, + signature_hex=self.signature_hex, + ) + self.assertTrue(self.query.exists()) + + def test_invalid_signature(self): + with self.assertRaises(golem_messages.exceptions.InvalidSignature): + api_broadcast.push( + timestamp=str(self.timestamp), + broadcast_type=str(self.broadcast_type), + data_hex=self.data_hex, + signature_hex='7cf206f88696700f1a6f87c4a99a4bf11e8526a860f2a9d32345a3c1f9a95d985e1878ef60495e8deca4032d5622ffa02a3f059248084d07aee4dd4effead64501', # noqa pylint: disable=line-too-long + ) + self.assertFalse(self.query.exists()) + + def test_invalid_signature_invalid_hex(self): + with self.assertRaises(ValueError): + api_broadcast.push( + timestamp=str(self.timestamp), + broadcast_type=str(self.broadcast_type), + data_hex=self.data_hex, + signature_hex='bubliboo', + ) + self.assertFalse(self.query.exists()) + + +@mock.patch.object(active, 'BROADCAST_PUBKEY', PUBLIC_KEY) +class ListTest(BroadcastTestBase): + @freeze_time("2018-01-01 00:00:00") + def test_basic(self): + api_broadcast.push( + timestamp=self.timestamp, + broadcast_type=self.broadcast_type, + data_hex=self.data_hex, + signature_hex=self.signature_hex, + ) + result = api_broadcast.list_() + self.assertEqual( + result, + [ + { + 'timestamp': self.timestamp, + 'broadcast_type': self.broadcast_type, + 'broadcast_type_name': 'Version', + 'data_hex': self.data_hex, + 'created_date': 1514764800, + }, + ] + ) diff --git a/tests/golem/task/server/test_queue.py b/tests/golem/task/server/test_queue.py index 9d952f98e4..7d0d58decf 100644 --- a/tests/golem/task/server/test_queue.py +++ b/tests/golem/task/server/test_queue.py @@ -41,7 +41,7 @@ def setUp(self): def test_conn_established(self, *_): self.server.msg_queue_connection_established( - self.session, + mock.MagicMock(session=self.session), self.conn_id, self.node_id, )