Skip to content

Commit

Permalink
pydantic models
Browse files Browse the repository at this point in the history
  • Loading branch information
msbrogli committed Nov 3, 2023
1 parent b9b06e8 commit a340ba6
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 29 deletions.
56 changes: 27 additions & 29 deletions hathor/p2p/sync_v2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from hathor.p2p.sync_agent import SyncAgent
from hathor.p2p.sync_v2.blockchain_streaming_client import BlockchainStreamingClient, StreamingError
from hathor.p2p.sync_v2.mempool import SyncMempoolManager
from hathor.p2p.sync_v2.payloads import GetBestBlockPayload, GetNextBlocksPayload, GetTransactionsBFSPayload
from hathor.p2p.sync_v2.streamers import DEFAULT_STREAMING_LIMIT, BlockchainStreaming, StreamEnd, TransactionsStreaming
from hathor.p2p.sync_v2.transaction_streaming_client import TransactionStreamingClient
from hathor.transaction import BaseTransaction, Block, Transaction
Expand Down Expand Up @@ -609,12 +610,12 @@ def handle_peer_block_hashes(self, payload: str) -> None:
def send_get_next_blocks(self, start_hash: bytes, end_hash: bytes, quantity: int) -> None:
""" Send a PEER-BLOCK-HASHES message.
"""
payload = json.dumps(dict(
start_hash=start_hash.hex(),
end_hash=end_hash.hex(),
payload = GetNextBlocksPayload(
start_hash=start_hash,
end_hash=end_hash,
quantity=quantity,
))
self.send_message(ProtocolMessages.GET_NEXT_BLOCKS, payload)
)
self.send_message(ProtocolMessages.GET_NEXT_BLOCKS, payload.json())
self.receiving_stream = True

def handle_get_next_blocks(self, payload: str) -> None:
Expand All @@ -624,11 +625,11 @@ def handle_get_next_blocks(self, payload: str) -> None:
if self._is_streaming:
self.protocol.send_error_and_close_connection('GET-NEXT-BLOCKS received before previous one finished')
return
data = json.loads(payload)
data = GetNextBlocksPayload.parse_raw(payload)
self.send_next_blocks(
start_hash=bytes.fromhex(data['start_hash']),
end_hash=bytes.fromhex(data['end_hash']),
quantity=data['quantity'],
start_hash=data.start_hash,
end_hash=data.end_hash,
quantity=data.quantity,
)

def send_next_blocks(self, start_hash: bytes, end_hash: bytes, quantity: int) -> None:
Expand Down Expand Up @@ -755,22 +756,23 @@ def send_get_best_block(self) -> None:
"""
self.send_message(ProtocolMessages.GET_BEST_BLOCK)

def handle_get_best_block(self, payload: str) -> None:
def handle_get_best_block(self, _payload: str) -> None:
""" Handle a GET-BEST-BLOCK message.
"""
best_block = self.tx_storage.get_best_block()
meta = best_block.get_metadata()
assert meta.validation.is_fully_connected()
data = {'block': best_block.hash_hex, 'height': meta.height}
self.send_message(ProtocolMessages.BEST_BLOCK, json.dumps(data))
payload = GetBestBlockPayload(
block=best_block.hash,
height=meta.height,
)
self.send_message(ProtocolMessages.BEST_BLOCK, payload.json())

def handle_best_block(self, payload: str) -> None:
""" Handle a BEST-BLOCK message.
"""
data = json.loads(payload)
_id = bytes.fromhex(data['block'])
height = data['height']
best_block = _HeightInfo(height=height, id=_id)
data = GetBestBlockPayload.parse_raw(payload)
best_block = _HeightInfo(height=data.height, id=data.block)

deferred = self._deferred_best_block
self._deferred_best_block = None
Expand Down Expand Up @@ -808,12 +810,12 @@ def send_get_transactions_bfs(self,
start_from=start_from_hexlist,
first_block_hash=first_block_hash_hex,
last_block_hash=last_block_hash_hex)
payload = json.dumps(dict(
start_from=start_from_hexlist,
first_block_hash=first_block_hash_hex,
last_block_hash=last_block_hash_hex,
))
self.send_message(ProtocolMessages.GET_TRANSACTIONS_BFS, payload)
payload = GetTransactionsBFSPayload(
start_from=start_from,
first_block_hash=first_block_hash,
last_block_hash=last_block_hash,
)
self.send_message(ProtocolMessages.GET_TRANSACTIONS_BFS, payload.json())
self.receiving_stream = True

def handle_get_transactions_bfs(self, payload: str) -> None:
Expand All @@ -822,17 +824,13 @@ def handle_get_transactions_bfs(self, payload: str) -> None:
if self._is_streaming:
self.log.warn('ignore GET-TRANSACTIONS-BFS, already streaming')
return
data = json.loads(payload)
data = GetTransactionsBFSPayload.parse_raw(payload)
# XXX: todo verify this limit while parsing the payload.
start_from = [bytes.fromhex(h) for h in data['start_from']]
if len(start_from) > MAX_GET_TRANSACTIONS_BFS_LEN:
if len(data.start_from) > MAX_GET_TRANSACTIONS_BFS_LEN:
self.log.error('too many transactions in GET-TRANSACTIONS-BFS', state=self.state)
self.protocol.send_error_and_close_connection('Too many transactions in GET-TRANSACTIONS-BFS')
return
self.log.debug('handle_get_transactions_bfs', **data)
first_block_hash = bytes.fromhex(data['first_block_hash'])
last_block_hash = bytes.fromhex(data['last_block_hash'])
self.send_transactions_bfs(start_from, first_block_hash, last_block_hash)
self.send_transactions_bfs(data.start_from, data.first_block_hash, data.last_block_hash)

def send_transactions_bfs(self,
start_from: list[bytes],
Expand Down
66 changes: 66 additions & 0 deletions hathor/p2p/sync_v2/payloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2023 Hathor Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pydantic import validator

from hathor.utils.pydantic import BaseModel


class PayloadBaseModel(BaseModel):

@classmethod
def convert_hex_to_bytes(cls, value: str | bytes) -> bytes:
if isinstance(value, str):
return bytes.fromhex(value)
elif isinstance(value, bytes):
return value
raise ValueError('invalid type')

class Config:
json_encoders = {
bytes: lambda x: x.hex()
}


class GetNextBlocksPayload(PayloadBaseModel):
start_hash: bytes
end_hash: bytes
quantity: int

@validator('start_hash', 'end_hash', pre=True)
def validate_bytes_fields(cls, value: str | bytes) -> bytes:
return cls.convert_hex_to_bytes(value)


class GetBestBlockPayload(PayloadBaseModel):
block: bytes
height: int

@validator('block', pre=True)
def validate_bytes_fields(cls, value: str | bytes) -> bytes:
return cls.convert_hex_to_bytes(value)


class GetTransactionsBFSPayload(PayloadBaseModel):
start_from: list[bytes]
first_block_hash: bytes
last_block_hash: bytes

@validator('first_block_hash', 'last_block_hash', pre=True)
def validate_bytes_fields(cls, value: str | bytes) -> bytes:
return cls.convert_hex_to_bytes(value)

@validator('start_from', pre=True)
def validate_start_from(cls, values: list[str | bytes]) -> list[bytes]:
return [cls.convert_hex_to_bytes(x) for x in values]

0 comments on commit a340ba6

Please sign in to comment.