Skip to content

Commit

Permalink
Merge pull request #2 from callebtc/connection_state
Browse files Browse the repository at this point in the history
Connection_state
  • Loading branch information
callebtc authored Jan 24, 2023
2 parents d7fb45f + 06362a4 commit 56deff4
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions nostr/relay.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time
from threading import Lock
from websocket import WebSocketApp
from .event import Event
Expand Down Expand Up @@ -29,6 +30,11 @@ def __init__(
self.policy = policy
self.message_pool = message_pool
self.subscriptions = subscriptions
self.connected: bool = False
self.reconnect: bool = True
self.error_counter: int = 0
self.error_threshold: int = 0
self.ssl_options: dict = {}
self.lock = Lock()
self.ws = WebSocketApp(
url,
Expand All @@ -38,14 +44,26 @@ def __init__(
on_close=self._on_close,
)

def connect(self, ssl_options: dict = None):
self.ws.run_forever(sslopt=ssl_options)
def connect(self, ssl_options: dict = {}):
self.ssl_options = ssl_options
self.ws.run_forever(sslopt=self.ssl_options)

def close(self):
self.ws.close()

def check_reconnect(self):
try:
self.close()
except:
pass
self.connected = False
if self.reconnect:
time.sleep(1)
self.connect(self.ssl_options)

def publish(self, message: str):
self.ws.send(message)
if self.connected:
self.ws.send(message)

def add_subscription(self, id, filters: Filters):
with self.lock:
Expand All @@ -71,17 +89,25 @@ def to_json_object(self) -> dict:
}

def _on_open(self, class_obj):
self.connected = True
pass

def _on_close(self, class_obj, status_code, message):
self.connected = False
self.check_reconnect()
pass

def _on_message(self, class_obj, message: str):
if self._is_valid_message(message):
self.message_pool.add_message(message, self.url)

def _on_error(self, class_obj, error):
pass
self.connected = False
self.error_counter += 1
if self.error_threshold and self.error_counter > self.error_threshold:
pass
else:
self.check_reconnect()

def _is_valid_message(self, message: str) -> bool:
message = message.strip("\n")
Expand Down Expand Up @@ -117,7 +143,7 @@ def _is_valid_message(self, message: str) -> bool:
with self.lock:
subscription = self.subscriptions[subscription_id]

if not subscription.filters.match(event):
if subscription.filters and not subscription.filters.match(event):
return False

return True

0 comments on commit 56deff4

Please sign in to comment.