diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/main.py b/main.py new file mode 100644 index 0000000..a3c489b --- /dev/null +++ b/main.py @@ -0,0 +1,118 @@ +from nostr.client.client import NostrClient +from nostr.event import Event +from nostr.key import PublicKey +import asyncio +import threading +import time +import datetime + + +def print_status(client): + print("") + for relay in client.relay_manager.relays.values(): + connected_text = "🟢" if relay.connected else "🔴" + status_text = f"{connected_text} ⬆️ {relay.num_sent_events} ⬇️ {relay.num_received_events} ⚠️ {relay.error_counter} ⏱️ {relay.ping} ms - {relay.url.split('//')[1]}" + print(status_text) + + +async def dm(): + print("This is an example NIP-04 DM flow") + pk = input("Enter your privatekey to post from (enter nothing for a random one): ") + + def callback(event: Event, decrypted_content): + """ + Callback to trigger when a DM is received. + """ + print( + f"\nFrom {event.public_key[:3]}..{event.public_key[-3:]}: {decrypted_content}" + ) + + client = NostrClient(private_key=pk) + if not pk: + print(f"Your private key: {client.private_key.bech32()}") + + print(f"Your public key: {client.public_key.bech32()}") + + t = threading.Thread( + target=client.get_dm, args=(client.public_key, callback), daemon=True + ) + t.start() + + pubkey_to_str = ( + input("Enter other pubkey to DM to (enter nothing to DM yourself): ") + or client.public_key.hex() + ) + if pubkey_to_str.startswith("npub"): + pubkey_to = PublicKey().from_npub(pubkey_to_str) + else: + pubkey_to = PublicKey(bytes.fromhex(pubkey_to_str)) + print(f"Sending DMs to {pubkey_to.bech32()}") + while True: + print_status(client) + await asyncio.sleep(1) + msg = input("\nEnter message: ") + client.dm(msg, pubkey_to) + + +async def post(): + print("This posts and reads a nostr note") + pk = input("Enter your privatekey to post from (enter nothing for a random one): ") + + def callback(event: Event): + """ + Callback to trigger when post appers. + """ + print( + f"\nFrom {event.public_key[:3]}..{event.public_key[-3:]}: {event.content}" + ) + + sender_client = NostrClient(private_key=pk) + # await asyncio.sleep(1) + + pubkey_to_str = ( + input( + "Enter other pubkey (enter nothing to read your own posts, enter * for all): " + ) + or sender_client.public_key.hex() + ) + if pubkey_to_str == "*": + pubkey_to = None + elif pubkey_to_str.startswith("npub"): + pubkey_to = PublicKey().from_npub(pubkey_to_str) + else: + pubkey_to = PublicKey(bytes.fromhex(pubkey_to_str)) + + print(f"Subscribing to posts by {pubkey_to.bech32() if pubkey_to else 'everyone'}") + + filters = { + "since": int( + time.mktime( + (datetime.datetime.now() - datetime.timedelta(hours=1)).timetuple() + ) + ) + } + + t = threading.Thread( + target=sender_client.get_post, + args=( + pubkey_to, + callback, + filters, + ), + daemon=True, + ) + t.start() + + while True: + print_status(sender_client) + await asyncio.sleep(1) + msg = input("\nEnter post: ") + sender_client.post(msg) + + +if input("Enter '1' for DM, '2' for Posts (Default: 1): ") == "2": + # make a post and subscribe to posts + asyncio.run(post()) +else: + # write a DM and receive DMs + asyncio.run(dm()) diff --git a/nostr/client/__init__.py b/nostr/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nostr/client/cbc.py b/nostr/client/cbc.py new file mode 100644 index 0000000..a41dbc0 --- /dev/null +++ b/nostr/client/cbc.py @@ -0,0 +1,41 @@ + +from Cryptodome import Random +from Cryptodome.Cipher import AES + +plain_text = "This is the text to encrypts" + +# encrypted = "7mH9jq3K9xNfWqIyu9gNpUz8qBvGwsrDJ+ACExdV1DvGgY8q39dkxVKeXD7LWCDrPnoD/ZFHJMRMis8v9lwHfNgJut8EVTMuJJi8oTgJevOBXl+E+bJPwej9hY3k20rgCQistNRtGHUzdWyOv7S1tg==".encode() +# iv = "GzDzqOVShWu3Pl2313FBpQ==".encode() + +key = bytes.fromhex("3aa925cb69eb613e2928f8a18279c78b1dca04541dfd064df2eda66b59880795") + +BLOCK_SIZE = 16 + +class AESCipher(object): + """This class is compatible with crypto.createCipheriv('aes-256-cbc') + + """ + def __init__(self, key=None): + self.key = key + + def pad(self, data): + length = BLOCK_SIZE - (len(data) % BLOCK_SIZE) + return data + (chr(length) * length).encode() + + def unpad(self, data): + return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))] + + def encrypt(self, plain_text): + cipher = AES.new(self.key, AES.MODE_CBC) + b = plain_text.encode("UTF-8") + return cipher.iv, cipher.encrypt(self.pad(b)) + + def decrypt(self, iv, enc_text): + cipher = AES.new(self.key, AES.MODE_CBC, iv=iv) + return self.unpad(cipher.decrypt(enc_text).decode("UTF-8")) + +if __name__ == "__main__": + aes = AESCipher(key=key) + iv, enc_text = aes.encrypt(plain_text) + dec_text = aes.decrypt(iv, enc_text) + print(dec_text) \ No newline at end of file diff --git a/nostr/client/client.py b/nostr/client/client.py new file mode 100644 index 0000000..74bafea --- /dev/null +++ b/nostr/client/client.py @@ -0,0 +1,167 @@ +from typing import * +import ssl +import time +import json +import os +import base64 + +from ..event import Event +from ..relay_manager import RelayManager +from ..message_type import ClientMessageType +from ..key import PrivateKey, PublicKey + +from ..filter import Filter, Filters +from ..event import Event, EventKind, EncryptedDirectMessage +from ..relay_manager import RelayManager +from ..message_type import ClientMessageType + +# from aes import AESCipher +from . import cbc + + +class NostrClient: + relays = [ + # "wss://eagerporpoise9.lnbits.com/nostrclient/api/v1/relay", + "wss://localhost:5001/nostrclient/api/v1/relay", + # "wss://nostr-pub.wellorder.net", + # "wss://relay.damus.io", + # "wss://nostr.zebedee.cloud", + # "wss://relay.snort.social", + # "wss://nostr.fmt.wiz.biz", + # "wss://nos.lol", + # "wss://nostr.oxtr.dev", + # "wss://relay.current.fyi", + # "wss://relay.snort.social", + ] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr" + relay_manager = RelayManager() + private_key: PrivateKey + public_key: PublicKey + + def __init__(self, private_key: str = "", relays: List[str] = [], connect=True): + self.generate_keys(private_key) + + if len(relays): + self.relays = relays + if connect: + self.connect() + + def connect(self): + for relay in self.relays: + self.relay_manager.add_relay(relay) + self.relay_manager.open_connections( + {"cert_reqs": ssl.CERT_NONE} + ) # NOTE: This disables ssl certificate verification + + def close(self): + self.relay_manager.close_connections() + + def generate_keys(self, private_key: str = None): + if private_key.startswith("nsec"): + self.private_key = PrivateKey.from_nsec(private_key) + elif private_key: + self.private_key = PrivateKey(bytes.fromhex(private_key)) + else: + self.private_key = PrivateKey() # generate random key + self.public_key = self.private_key.public_key + + def post(self, message: str): + event = Event(message, self.public_key.hex(), kind=EventKind.TEXT_NOTE) + self.private_key.sign_event(event) + event_json = event.to_message() + # print("Publishing message:") + # print(event_json) + self.relay_manager.publish_message(event_json) + + def get_post( + self, sender_publickey: PublicKey = None, callback_func=None, filter_kwargs={} + ): + filter = Filter( + authors=[sender_publickey.hex()] if sender_publickey else None, + kinds=[EventKind.TEXT_NOTE], + **filter_kwargs, + ) + filters = Filters([filter]) + subscription_id = os.urandom(4).hex() + self.relay_manager.add_subscription(subscription_id, filters) + + request = [ClientMessageType.REQUEST, subscription_id] + request.extend(filters.to_json_array()) + message = json.dumps(request) + # print(message) + self.relay_manager.publish_message(message) + + while True: + while self.relay_manager.message_pool.has_events(): + event_msg = self.relay_manager.message_pool.get_event() + if callback_func: + callback_func(event_msg.event) + time.sleep(0.1) + + def dm(self, message: str, to_pubkey: PublicKey): + dm = EncryptedDirectMessage( + recipient_pubkey=to_pubkey.hex(), cleartext_content=message + ) + self.private_key.sign_event(dm) + # print(dm) + self.relay_manager.publish_event(dm) + + def get_dm(self, sender_publickey: PublicKey, callback_func=None, filter_kwargs={}): + filters = Filters( + [ + Filter( + kinds=[EventKind.ENCRYPTED_DIRECT_MESSAGE], + pubkey_refs=[sender_publickey.hex()], + **filter_kwargs, + ) + ] + ) + subscription_id = os.urandom(4).hex() + self.relay_manager.add_subscription(subscription_id, filters) + + request = [ClientMessageType.REQUEST, subscription_id] + request.extend(filters.to_json_array()) + message = json.dumps(request) + self.relay_manager.publish_message(message) + # print(message) + while True: + while self.relay_manager.message_pool.has_events(): + event_msg = self.relay_manager.message_pool.get_event() + if "?iv=" in event_msg.event.content: + try: + shared_secret = self.private_key.compute_shared_secret( + event_msg.event.public_key + ) + aes = cbc.AESCipher(key=shared_secret) + enc_text_b64, iv_b64 = event_msg.event.content.split("?iv=") + iv = base64.decodebytes(iv_b64.encode("utf-8")) + enc_text = base64.decodebytes(enc_text_b64.encode("utf-8")) + dec_text = aes.decrypt(iv, enc_text) + if callback_func: + callback_func(event_msg.event, dec_text) + except: + pass + break + time.sleep(0.1) + + def subscribe( + self, + callback_events_func=None, + callback_notices_func=None, + callback_eosenotices_func=None, + ): + while True: + while self.relay_manager.message_pool.has_events(): + event_msg = self.relay_manager.message_pool.get_event() + print(event_msg.event.content) + if callback_events_func: + callback_events_func(event_msg) + while self.relay_manager.message_pool.has_notices(): + event_msg = self.relay_manager.message_pool.has_notices() + if callback_notices_func: + callback_notices_func(event_msg) + while self.relay_manager.message_pool.has_eose_notices(): + event_msg = self.relay_manager.message_pool.get_eose_notice() + if callback_eosenotices_func: + callback_eosenotices_func(event_msg) + + time.sleep(0.1) diff --git a/nostr/event.py b/nostr/event.py index 11f56c6..b903e0e 100644 --- a/nostr/event.py +++ b/nostr/event.py @@ -3,11 +3,10 @@ from dataclasses import dataclass, field from enum import IntEnum from typing import List -from secp256k1 import PrivateKey, PublicKey +from secp256k1 import PublicKey from hashlib import sha256 -from nostr.message_type import ClientMessageType - +from .message_type import ClientMessageType class EventKind(IntEnum): @@ -19,17 +18,17 @@ class EventKind(IntEnum): DELETE = 5 - @dataclass class Event: content: str = None public_key: str = None created_at: int = None kind: int = EventKind.TEXT_NOTE - tags: List[List[str]] = field(default_factory=list) # Dataclasses require special handling when the default value is a mutable type + tags: List[List[str]] = field( + default_factory=list + ) # Dataclasses require special handling when the default value is a mutable type signature: str = None - def __post_init__(self): if self.content is not None and not isinstance(self.content, str): # DMs initialize content to None but all other kinds should pass in a str @@ -38,39 +37,44 @@ def __post_init__(self): if self.created_at is None: self.created_at = int(time.time()) - @staticmethod - def serialize(public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str) -> bytes: + def serialize( + public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str + ) -> bytes: data = [0, public_key, created_at, kind, tags, content] - data_str = json.dumps(data, separators=(',', ':'), ensure_ascii=False) + data_str = json.dumps(data, separators=(",", ":"), ensure_ascii=False) return data_str.encode() - @staticmethod - def compute_id(public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str): - return sha256(Event.serialize(public_key, created_at, kind, tags, content)).hexdigest() - + def compute_id( + public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str + ): + return sha256( + Event.serialize(public_key, created_at, kind, tags, content) + ).hexdigest() @property def id(self) -> str: # Always recompute the id to reflect the up-to-date state of the Event - return Event.compute_id(self.public_key, self.created_at, self.kind, self.tags, self.content) - - - def add_pubkey_ref(self, pubkey:str): - """ Adds a reference to a pubkey as a 'p' tag """ - self.tags.append(['p', pubkey]) - + return Event.compute_id( + self.public_key, self.created_at, self.kind, self.tags, self.content + ) - def add_event_ref(self, event_id:str): - """ Adds a reference to an event_id as an 'e' tag """ - self.tags.append(['e', event_id]) + def add_pubkey_ref(self, pubkey: str): + """Adds a reference to a pubkey as a 'p' tag""" + self.tags.append(["p", pubkey]) + def add_event_ref(self, event_id: str): + """Adds a reference to an event_id as an 'e' tag""" + self.tags.append(["e", event_id]) def verify(self) -> bool: - pub_key = PublicKey(bytes.fromhex("02" + self.public_key), True) # add 02 for schnorr (bip340) - return pub_key.schnorr_verify(bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True) - + pub_key = PublicKey( + bytes.fromhex("02" + self.public_key), True + ) # add 02 for schnorr (bip340) + return pub_key.schnorr_verify( + bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True + ) def to_message(self) -> str: return json.dumps( @@ -83,20 +87,18 @@ def to_message(self) -> str: "kind": self.kind, "tags": self.tags, "content": self.content, - "sig": self.signature - } + "sig": self.signature, + }, ] ) - @dataclass class EncryptedDirectMessage(Event): recipient_pubkey: str = None cleartext_content: str = None reference_event_id: str = None - def __post_init__(self): if self.content is not None: self.cleartext_content = self.content @@ -115,9 +117,10 @@ def __post_init__(self): if self.reference_event_id is not None: self.add_event_ref(self.reference_event_id) - @property def id(self) -> str: if self.content is None: - raise Exception("EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field") + raise Exception( + "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field" + ) return super().id diff --git a/nostr/filter.py b/nostr/filter.py index f4cb0a5..f119079 100644 --- a/nostr/filter.py +++ b/nostr/filter.py @@ -4,7 +4,6 @@ from .event import Event, EventKind - class Filter: """ NIP-01 filtering. @@ -16,20 +15,26 @@ class Filter: added. For example: # arbitrary tag filter.add_arbitrary_tag('t', [hashtags]) - + # promoted to explicit support Filter(hashtag_refs=[hashtags]) """ + def __init__( - self, - event_ids: List[str] = None, - kinds: List[EventKind] = None, - authors: List[str] = None, - since: int = None, - until: int = None, - event_refs: List[str] = None, # the "#e" attr; list of event ids referenced in an "e" tag - pubkey_refs: List[str] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag - limit: int = None) -> None: + self, + event_ids: List[str] = None, + kinds: List[EventKind] = None, + authors: List[str] = None, + since: int = None, + until: int = None, + event_refs: List[ + str + ] = None, # the "#e" attr; list of event ids referenced in an "e" tag + pubkey_refs: List[ + str + ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag + limit: int = None, + ) -> None: self.event_ids = event_ids self.kinds = kinds self.authors = authors @@ -41,21 +46,19 @@ def __init__( self.tags = {} if self.event_refs: - self.add_arbitrary_tag('e', self.event_refs) + self.add_arbitrary_tag("e", self.event_refs) if self.pubkey_refs: - self.add_arbitrary_tag('p', self.pubkey_refs) - + self.add_arbitrary_tag("p", self.pubkey_refs) def add_arbitrary_tag(self, tag: str, values: list): """ - Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12 - single-letter tags. + Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12 + single-letter tags. """ - # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#" + # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#" tag_key = tag if len(tag) > 1 else f"#{tag}" self.tags[tag_key] = values - def matches(self, event: Event) -> bool: if self.event_ids is not None and event.id not in self.event_ids: return False @@ -67,7 +70,9 @@ def matches(self, event: Event) -> bool: return False if self.until is not None and event.created_at > self.until: return False - if (self.event_refs is not None or self.pubkey_refs is not None) and len(event.tags) == 0: + if (self.event_refs is not None or self.pubkey_refs is not None) and len( + event.tags + ) == 0: return False if self.tags: @@ -79,7 +84,7 @@ def matches(self, event: Event) -> bool: if f_tag not in e_tag_identifiers: # Event is missing a tag type that we're looking for return False - + # Multiple values within f_tag_values are treated as OR search; an Event # needs to match only one. # Note: an Event could have multiple entries of the same tag type @@ -94,12 +99,11 @@ def matches(self, event: Event) -> bool: return True - def to_json_object(self) -> dict: res = {} if self.event_ids is not None: res["ids"] = self.event_ids - if self.kinds is not None: + if self.kinds is not None: res["kinds"] = self.kinds if self.authors is not None: res["authors"] = self.authors @@ -115,9 +119,8 @@ def to_json_object(self) -> dict: return res - class Filters(UserList): - def __init__(self, initlist: "list[Filter]"=[]) -> None: + def __init__(self, initlist: "list[Filter]" = []) -> None: super().__init__(initlist) self.data: "list[Filter]" @@ -128,4 +131,4 @@ def match(self, event: Event): return False def to_json_array(self) -> list: - return [filter.to_json_object() for filter in self.data] \ No newline at end of file + return [filter.to_json_object() for filter in self.data] diff --git a/nostr/key.py b/nostr/key.py index 350c72d..6988964 100644 --- a/nostr/key.py +++ b/nostr/key.py @@ -6,13 +6,13 @@ from cryptography.hazmat.primitives import padding from hashlib import sha256 -from nostr.delegation import Delegation -from nostr.event import EncryptedDirectMessage, Event, EventKind +from .delegation import Delegation +from .event import EncryptedDirectMessage, Event, EventKind from . import bech32 class PublicKey: - def __init__(self, raw_bytes: bytes) -> None: + def __init__(self, raw_bytes: bytes = None) -> None: self.raw_bytes = raw_bytes def bech32(self) -> str: @@ -28,14 +28,14 @@ def verify_signed_message_hash(self, hash: str, sig: str) -> bool: @classmethod def from_npub(cls, npub: str): - """ Load a PublicKey from its bech32/npub form """ + """Load a PublicKey from its bech32/npub form""" hrp, data, spec = bech32.bech32_decode(npub) raw_public_key = bech32.convertbits(data, 5, 8)[:-1] return cls(bytes(raw_public_key)) class PrivateKey: - def __init__(self, raw_secret: bytes=None) -> None: + def __init__(self, raw_secret: bytes = None) -> None: if not raw_secret is None: self.raw_secret = raw_secret else: @@ -46,7 +46,7 @@ def __init__(self, raw_secret: bytes=None) -> None: @classmethod def from_nsec(cls, nsec: str): - """ Load a PrivateKey from its bech32/nsec form """ + """Load a PrivateKey from its bech32/nsec form""" hrp, data, spec = bech32.bech32_decode(nsec) raw_secret = bech32.convertbits(data, 5, 8)[:-1] return cls(bytes(raw_secret)) @@ -71,22 +71,28 @@ def encrypt_message(self, message: str, public_key_hex: str) -> str: padded_data = padder.update(message.encode()) + padder.finalize() iv = secrets.token_bytes(16) - cipher = Cipher(algorithms.AES(self.compute_shared_secret(public_key_hex)), modes.CBC(iv)) + cipher = Cipher( + algorithms.AES(self.compute_shared_secret(public_key_hex)), modes.CBC(iv) + ) encryptor = cipher.encryptor() encrypted_message = encryptor.update(padded_data) + encryptor.finalize() return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}" - + def encrypt_dm(self, dm: EncryptedDirectMessage) -> None: - dm.content = self.encrypt_message(message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey) + dm.content = self.encrypt_message( + message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey + ) def decrypt_message(self, encoded_message: str, public_key_hex: str) -> str: - encoded_data = encoded_message.split('?iv=') + encoded_data = encoded_message.split("?iv=") encoded_content, encoded_iv = encoded_data[0], encoded_data[1] iv = base64.b64decode(encoded_iv) - cipher = Cipher(algorithms.AES(self.compute_shared_secret(public_key_hex)), modes.CBC(iv)) + cipher = Cipher( + algorithms.AES(self.compute_shared_secret(public_key_hex)), modes.CBC(iv) + ) encrypted_content = base64.b64decode(encoded_content) decryptor = cipher.decryptor() @@ -110,7 +116,9 @@ def sign_event(self, event: Event) -> None: event.signature = self.sign_message_hash(bytes.fromhex(event.id)) def sign_delegation(self, delegation: Delegation) -> None: - delegation.signature = self.sign_message_hash(sha256(delegation.delegation_token.encode()).digest()) + delegation.signature = self.sign_message_hash( + sha256(delegation.delegation_token.encode()).digest() + ) def __eq__(self, other): return self.raw_secret == other.raw_secret @@ -122,9 +130,12 @@ def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey: while True: sk = PrivateKey() - if prefix is not None and not sk.public_key.bech32()[5:5+len(prefix)] == prefix: + if ( + prefix is not None + and not sk.public_key.bech32()[5 : 5 + len(prefix)] == prefix + ): continue - if suffix is not None and not sk.public_key.bech32()[-len(suffix):] == suffix: + if suffix is not None and not sk.public_key.bech32()[-len(suffix) :] == suffix: continue break @@ -132,7 +143,11 @@ def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey: ffi = FFI() -@ffi.callback("int (unsigned char *, const unsigned char *, const unsigned char *, void *)") + + +@ffi.callback( + "int (unsigned char *, const unsigned char *, const unsigned char *, void *)" +) def copy_x(output, x32, y32, data): ffi.memmove(output, x32, 32) return 1 diff --git a/nostr/message_pool.py b/nostr/message_pool.py index ac46b24..d364cf2 100644 --- a/nostr/message_pool.py +++ b/nostr/message_pool.py @@ -4,22 +4,26 @@ from .message_type import RelayMessageType from .event import Event + class EventMessage: def __init__(self, event: Event, subscription_id: str, url: str) -> None: self.event = event self.subscription_id = subscription_id self.url = url + class NoticeMessage: def __init__(self, content: str, url: str) -> None: self.content = content self.url = url + class EndOfStoredEventsMessage: def __init__(self, subscription_id: str, url: str) -> None: self.subscription_id = subscription_id self.url = url + class MessagePool: def __init__(self) -> None: self.events: Queue[EventMessage] = Queue() @@ -27,7 +31,7 @@ def __init__(self) -> None: self.eose_notices: Queue[EndOfStoredEventsMessage] = Queue() self._unique_events: set = set() self.lock: Lock = Lock() - + def add_message(self, message: str, url: str): self._process_message(message, url) @@ -55,7 +59,14 @@ def _process_message(self, message: str, url: str): if message_type == RelayMessageType.EVENT: subscription_id = message_json[1] e = message_json[2] - event = Event(e['pubkey'], e['content'], e['created_at'], e['kind'], e['tags'], e['id'], e['sig']) + event = Event( + e["content"], + e["pubkey"], + e["created_at"], + e["kind"], + e["tags"], + e["sig"], + ) with self.lock: if not event.id in self._unique_events: self.events.put(EventMessage(event, subscription_id, url)) @@ -64,5 +75,3 @@ def _process_message(self, message: str, url: str): self.notices.put(NoticeMessage(message_json[1], url)) elif message_type == RelayMessageType.END_OF_STORED_EVENTS: self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url)) - - diff --git a/nostr/relay.py b/nostr/relay.py index 373a259..670480c 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -1,4 +1,6 @@ import json +import time +from queue import Queue from threading import Lock from websocket import WebSocketApp from .event import Event @@ -7,49 +9,90 @@ from .message_type import RelayMessageType from .subscription import Subscription + class RelayPolicy: - def __init__(self, should_read: bool=True, should_write: bool=True) -> None: + def __init__(self, should_read: bool = True, should_write: bool = True) -> None: self.should_read = should_read self.should_write = should_write def to_json_object(self) -> dict[str, bool]: - return { - "read": self.should_read, - "write": self.should_write - } + return {"read": self.should_read, "write": self.should_write} + class Relay: def __init__( - self, - url: str, - policy: RelayPolicy, - message_pool: MessagePool, - subscriptions: dict[str, Subscription]={}) -> None: + self, + url: str, + policy: RelayPolicy, + message_pool: MessagePool, + subscriptions: dict[str, Subscription] = {}, + ) -> None: self.url = url self.policy = policy self.message_pool = message_pool self.subscriptions = subscriptions + self.connected: bool = False + self.reconnect: bool = True + self.error_counter: int = 0 + self.error_threshold: int = 0 + self.num_received_events: int = 0 + self.num_sent_events: int = 0 + self.num_subscriptions: int = 0 + self.ssl_options: dict = {} + self.proxy: dict = {} self.lock = Lock() + self.queue = Queue() self.ws = WebSocketApp( url, on_open=self._on_open, on_message=self._on_message, on_error=self._on_error, - on_close=self._on_close) - - def connect(self, ssl_options: dict=None, proxy: dict=None): - self.ws.run_forever( - sslopt=ssl_options, - http_proxy_host=None if proxy is None else proxy.get('host'), - http_proxy_port=None if proxy is None else proxy.get('port'), - proxy_type=None if proxy is None else proxy.get('type') + on_close=self._on_close, + on_ping=self._on_ping, + on_pong=self._on_pong, ) + def connect(self, ssl_options: dict = None, proxy: dict = None): + self.ssl_options = ssl_options + self.proxy = proxy + if not self.connected: + self.ws.run_forever( + sslopt=ssl_options, + http_proxy_host=None if proxy is None else proxy.get("host"), + http_proxy_port=None if proxy is None else proxy.get("port"), + proxy_type=None if proxy is None else proxy.get("type"), + ping_interval=5, + ) + def close(self): self.ws.close() + def check_reconnect(self): + try: + self.close() + except: + pass + self.connected = False + if self.reconnect: + time.sleep(1) + self.connect(self.ssl_options, self.proxy) + + @property + def ping(self): + ping_ms = int((self.ws.last_pong_tm - self.ws.last_ping_tm) * 1000) + return ping_ms if self.connected and ping_ms > 0 else 0 + def publish(self, message: str): - self.ws.send(message) + self.queue.put(message) + + def queue_worker(self): + while True: + if self.connected: + message = self.queue.get() + self.num_sent_events += 1 + self.ws.send(message) + else: + time.sleep(0.1) def add_subscription(self, id, filters: Filters): with self.lock: @@ -57,7 +100,7 @@ def add_subscription(self, id, filters: Filters): def close_subscription(self, id: str) -> None: with self.lock: - self.subscriptions.pop(id) + self.subscriptions.pop(id, None) def update_subscription(self, id: str, filters: Filters) -> None: with self.lock: @@ -68,25 +111,42 @@ def to_json_object(self) -> dict: return { "url": self.url, "policy": self.policy.to_json_object(), - "subscriptions": [subscription.to_json_object() for subscription in self.subscriptions.values()] + "subscriptions": [ + subscription.to_json_object() + for subscription in self.subscriptions.values() + ], } def _on_open(self, class_obj): + self.connected = True pass def _on_close(self, class_obj, status_code, message): + self.connected = False pass def _on_message(self, class_obj, message: str): if self._is_valid_message(message): + self.num_received_events += 1 self.message_pool.add_message(message, self.url) - + def _on_error(self, class_obj, error): - pass + self.connected = False + self.error_counter += 1 + if self.error_threshold and self.error_counter > self.error_threshold: + pass + else: + self.check_reconnect() + + def _on_ping(self, class_obj, message): + return + + def _on_pong(self, class_obj, message): + return def _is_valid_message(self, message: str) -> bool: message = message.strip("\n") - if not message or message[0] != '[' or message[-1] != ']': + if not message or message[0] != "[" or message[-1] != "]": return False message_json = json.loads(message) @@ -96,21 +156,28 @@ def _is_valid_message(self, message: str) -> bool: if message_type == RelayMessageType.EVENT: if not len(message_json) == 3: return False - + subscription_id = message_json[1] with self.lock: if subscription_id not in self.subscriptions: return False e = message_json[2] - event = Event(e['pubkey'], e['content'], e['created_at'], e['kind'], e['tags'], e['id'], e['sig']) + event = Event( + e["content"], + e["pubkey"], + e["created_at"], + e["kind"], + e["tags"], + e["sig"], + ) if not event.verify(): return False with self.lock: subscription = self.subscriptions[subscription_id] - if not subscription.filters.match(event): + if subscription.filters and not subscription.filters.match(event): return False return True diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index 191f5bd..5b92d8d 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -8,18 +8,18 @@ from .relay import Relay, RelayPolicy - class RelayException(Exception): pass - class RelayManager: def __init__(self) -> None: self.relays: dict[str, Relay] = {} self.message_pool = MessagePool() - def add_relay(self, url: str, read: bool=True, write: bool=True, subscriptions={}): + def add_relay( + self, url: str, read: bool = True, write: bool = True, subscriptions={} + ): policy = RelayPolicy(read, write) relay = Relay(url, policy, self.message_pool, subscriptions) self.relays[url] = relay @@ -35,12 +35,17 @@ def close_subscription(self, id: str): for relay in self.relays.values(): relay.close_subscription(id) - def open_connections(self, ssl_options: dict=None, proxy: dict=None): + def open_connections(self, ssl_options: dict = None, proxy: dict = None): for relay in self.relays.values(): threading.Thread( target=relay.connect, args=(ssl_options, proxy), - name=f"{relay.url}-thread" + name=f"{relay.url}-thread", + daemon=True, + ).start() + + threading.Thread( + target=relay.queue_worker, name=f"{relay.url}-queue", daemon=True ).start() def close_connections(self): @@ -53,11 +58,12 @@ def publish_message(self, message: str): relay.publish(message) def publish_event(self, event: Event): - """ Verifies that the Event is publishable before submitting it to relays """ + """Verifies that the Event is publishable before submitting it to relays""" if event.signature is None: raise RelayException(f"Could not publish {event.id}: must be signed") if not event.verify(): - raise RelayException(f"Could not publish {event.id}: failed to verify signature {event.signature}") - + raise RelayException( + f"Could not publish {event.id}: failed to verify signature {event.signature}" + ) self.publish_message(event.to_message()) diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..4e433ab --- /dev/null +++ b/poetry.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Poetry and should not be changed by hand. +package = [] + +[metadata] +lock-version = "2.0" +python-versions = "*" +content-hash = "115cf985d932e9bf5f540555bbdd75decbb62cac81e399375fc19f6277f8c1d8" diff --git a/pyproject.toml b/pyproject.toml index 417a873..9d733ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,9 @@ +[tool.poetry] +name = "python-nostr" +version = "0.1.0" +description = "" +authors = ["Your Name "] + [build-system] requires = ["setuptools", "setuptools-scm"] build-backend = "setuptools.build_meta" @@ -33,4 +39,4 @@ write_to = "nostr/_version.py" test = [ "pytest >=7.2.0", "pytest-cov[all]" -] \ No newline at end of file +]