Skip to content

Commit

Permalink
Apply ipv8 dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Oct 11, 2021
1 parent 3ab4f33 commit a016b30
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def request_tags(self):
if self.request_controller:
self.request_controller.register_peer(peer, REQUESTED_TAGS_COUNT)
self.logger.info(f'Request {REQUESTED_TAGS_COUNT} tags')
self.ez_send(peer, RequestTagOperationMessage(REQUESTED_TAGS_COUNT))
self.ez_send(peer, RequestTagOperationMessage(count=REQUESTED_TAGS_COUNT))

@lazy_wrapper(TagOperationMessage)
def on_message(self, peer, payload):
Expand All @@ -67,7 +67,7 @@ def on_message(self, peer, payload):
if self.crypto:
self.crypto.validate_signature(payload)
with db_session():
self.db.add_tag_operation(payload.infohash, payload.tag.decode(), payload.operation, payload.time,
self.db.add_tag_operation(payload.infohash, payload.tag, payload.operation, payload.time,
payload.creator_public_key, payload.signature)
self.logger.info(f'Tag added: {payload.tag}:{payload.infohash}')

Expand All @@ -90,9 +90,14 @@ def on_request(self, peer, payload):
self.logger.debug(f'Response {len(random_tag_operations)} tags')
for tag_operation in random_tag_operations:
try:
payload = TagOperationMessage(tag_operation.torrent_tag.torrent.infohash, tag_operation.operation,
tag_operation.time, tag_operation.peer.public_key,
tag_operation.signature, tag_operation.torrent_tag.tag.name.encode())
payload = TagOperationMessage(
infohash=tag_operation.torrent_tag.torrent.infohash,
operation=tag_operation.operation,
time=tag_operation.time,
creator_public_key=tag_operation.peer.public_key,
signature=tag_operation.signature,
tag=tag_operation.torrent_tag.tag.name,
)

if self.validator:
self.validator.validate_message(payload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,21 @@
from ipv8.messaging.serialization import default_serializer
from ipv8.types import Key
from tribler_core.components.tag.community.tag_payload import TagOperationMessage
from tribler_core.components.tag.db.tag_db import TagOperation


class TagCrypto:
@staticmethod
def _pack(message: VariablePayload) -> bytes:
def _pack(message: TagOperationMessage) -> bytes:
""" Pack a message to bytes by using default ipv8 serializer
"""
to_pack = copy.copy(message)
to_pack.signature = b'' # this field is excluded from signing
return default_serializer.pack_serializable(to_pack)

@staticmethod
def sign(infohash: bytes, tag: str, operation: TagOperation, time: int, creator_public_key: bytes,
key: Key) -> bytes:
def sign(message: TagOperationMessage, key: Key) -> bytes:
""" Sign arguments by using peer's private key
"""
message = TagOperationMessage(infohash, operation, time, creator_public_key, b'', tag.encode())
return default_eccrypto.create_signature(key, TagCrypto._pack(message))

def validate_signature(self, message: TagOperationMessage):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
from ipv8.messaging.payload_dataclass import dataclass, overwrite_dataclass, type_from_format

dataclass = overwrite_dataclass(dataclass)

@vp_compile
class RequestTagOperationMessage(VariablePayload):
msg_id = 1

format_list = ['I']
names = ['count']
@dataclass(msg_id=1)
class RequestTagOperationMessage:
count: int


@vp_compile
class TagOperationMessage(VariablePayload):
msg_id = 2

format_list = ['20s', 'I', 'I', '74s', '64s', 'raw']
names = ['infohash', 'operation', 'time', 'creator_public_key', 'signature', 'tag']
@dataclass(msg_id=2)
class TagOperationMessage:
infohash: type_from_format('20s')
operation: int
time: int
creator_public_key: type_from_format('74s')
signature: type_from_format('64s')
tag: str
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from ipv8.messaging.lazy_payload import VariablePayload
from tribler_common.tag_constants import MAX_TAG_LENGTH, MIN_TAG_LENGTH
from tribler_core.components.tag.community.tag_payload import TagOperationMessage
from tribler_core.components.tag.db.tag_db import TagOperation


class TagValidator:
@staticmethod
def validate_message(message: VariablePayload):
tag: str = message.tag.decode()
tag_length = len(tag)
def validate_message(message: TagOperationMessage):
tag_length = len(message.tag)
if not MIN_TAG_LENGTH <= tag_length <= MAX_TAG_LENGTH:
raise ValueError('Tag length should be in range [3..50]')

if any(ch.isupper() for ch in tag):
if any(ch.isupper() for ch in message.tag):
raise ValueError('Tag should not contain upper-case letters')

# try to convert operation into Enum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ipv8.test.mocking.ipv8 import MockIPv8
from tribler_core.components.tag.community.tag_community import TagCommunity
from tribler_core.components.tag.community.tag_crypto import TagCrypto
from tribler_core.components.tag.community.tag_payload import TagOperationMessage
from tribler_core.components.tag.community.tag_request_controller import PeerValidationError, TagRequestController
from tribler_core.components.tag.community.tag_validator import TagValidator
from tribler_core.components.tag.db.tag_db import TagDatabase, TagOperation
Expand All @@ -33,19 +34,24 @@ async def fill_db(self):
# first 5 of them are correct
# next 5 of them are incorrect
for i in range(10):
infohash = f'{i}'.encode() * 20
tag = f'{i}' * 3
operation = TagOperation.ADD
time = 1
creator_public_key = self.overlay(0).my_peer.public_key.key_to_bin()
signature = TagCrypto.sign(infohash, tag, operation, time, creator_public_key,
key=self.overlay(0).my_peer.key)
tag_operation = TagOperationMessage(
infohash=f'{i}'.encode() * 20,
operation=TagOperation.ADD,
time=1,
creator_public_key=self.overlay(0).my_peer.public_key.key_to_bin(),
signature=f'{1}'.encode() * 64,
tag=f'{i}' * 3
)

tag_operation.signature = TagCrypto.sign(tag_operation, key=self.overlay(0).my_peer.key)

# 5 of them are signed incorrectly
if i >= 5:
signature = f'{i}'.encode() * 64
tag_operation.signature = f'{i}'.encode() * 64

self.overlay(0).db.add_tag_operation(infohash, tag, operation, time, creator_public_key, signature)
self.overlay(0).db.add_tag_operation(tag_operation.infohash, tag_operation.tag, tag_operation.operation,
tag_operation.time, tag_operation.creator_public_key,
tag_operation.signature)

# put them into the past
for tag_op in self.overlay(0).db.instance.TorrentTagOp.select():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ def fixture_tag_crypto():

@pytest.fixture(name="random_message") # this workaround implemented only for pylint
def fixture_random_message():
return TagOperationMessage(f'{1}'.encode() * 20, TagOperation.ADD, 1, f'{1}'.encode() * 74, f'{1}'.encode() * 64,
''.encode())
return TagOperationMessage(
infohash=f'{1}'.encode() * 20,
operation=TagOperation.ADD,
time=1,
creator_public_key=f'{1}'.encode() * 74,
signature=f'{1}'.encode() * 64,
tag='',
)


@pytest.fixture(name="key") # this workaround implemented only for pylint
Expand All @@ -37,27 +43,19 @@ async def test_pack(tag_crypto: TagCrypto, random_message: TagOperationMessage):
assert random_message.signature


async def test_is_signature_valid(tag_crypto: TagCrypto, key: Key):
infohash = f'{1}'.encode() * 20
operation = TagOperation.ADD
time = 1
creator_public_key = key.pub().key_to_bin()
tag = 'tag'
async def test_is_signature_valid(random_message: TagOperationMessage, tag_crypto: TagCrypto, key: Key):
random_message.creator_public_key = key.pub().key_to_bin()
random_message.tag = 'tag'

message = TagOperationMessage(infohash, operation, time, creator_public_key, b'', tag.encode())
message.signature = tag_crypto.sign(infohash, tag, operation, time, creator_public_key, key)
tag_crypto.validate_signature(message)
random_message.signature = tag_crypto.sign(random_message, key)
tag_crypto.validate_signature(random_message)


async def test_is_signature_invalid(tag_crypto: TagCrypto, key: Key):
infohash = f'{1}'.encode() * 20
operation = TagOperation.ADD
time = 1
creator_public_key = key.pub().key_to_bin()
tag = 'tag'
async def test_is_signature_invalid(random_message: TagOperationMessage, tag_crypto: TagCrypto, key: Key):
random_message.creator_public_key = key.pub().key_to_bin()
random_message.tag = 'tag'

message = TagOperationMessage(infohash, operation, time, creator_public_key, b'', tag.encode())
message.signature = tag_crypto.sign(infohash, tag, operation, time, creator_public_key, key)
random_message.signature = tag_crypto.sign(random_message, key)
with pytest.raises(InvalidSignature):
message.tag = 'changed_tag'.encode()
tag_crypto.validate_signature(message)
random_message.tag = 'changed_tag'
tag_crypto.validate_signature(random_message)
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

from ipv8.messaging.lazy_payload import VariablePayload
from tribler_core.components.tag.community.tag_payload import TagOperationMessage
from tribler_core.components.tag.community.tag_validator import TagValidator
from tribler_core.components.tag.db.tag_db import TagOperation
Expand All @@ -16,8 +15,15 @@ def fixture_validator():

def create_message(infohash=b'infohash', operation: int = TagOperation.ADD, time: int = 0,
creator_public_key=b'creator_public_key', signature=b'signature',
tag: str = 'tag') -> VariablePayload:
return TagOperationMessage(infohash, operation, time, creator_public_key, signature, tag.encode())
tag: str = 'tag') -> TagOperationMessage:
return TagOperationMessage(
infohash=infohash,
operation=operation,
time=time,
creator_public_key=creator_public_key,
signature=signature,
tag=tag
)


async def test_correct_tag_size(validator: TagValidator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ipv8.types import Key
from tribler_common.tag_constants import MAX_TAG_LENGTH, MIN_TAG_LENGTH
from tribler_core.components.tag.community.tag_crypto import TagCrypto
from tribler_core.components.tag.community.tag_payload import TagOperationMessage
from tribler_core.components.tag.db.tag_db import TagDatabase, TagOperation
from tribler_core.restapi.rest_endpoint import HTTP_BAD_REQUEST, RESTEndpoint, RESTResponse
from tribler_core.restapi.schema import HandledErrorSchema
Expand Down Expand Up @@ -85,6 +86,14 @@ def modify_tags(self, infohash: bytes, new_tags: Set[str]):
for tag in added_tags.union(removed_tags):
operation = TagOperation.ADD if tag in added_tags else TagOperation.REMOVE
t = self.tags_db.get_last_time_of_operation(infohash, tag, public_key) + 1
signature = TagCrypto.sign(infohash, tag, operation, t, public_key, self.key)
tag_operation_message = TagOperationMessage(
infohash=infohash,
operation=operation,
time=t,
creator_public_key=public_key,
signature=b'',
tag=tag
)
signature = TagCrypto.sign(tag_operation_message, self.key)
self.tags_db.add_tag_operation(infohash, tag, operation, t, public_key, signature,
is_local_peer=True)

0 comments on commit a016b30

Please sign in to comment.