88
99import websockets
1010
11+ from web3 .exceptions import (
12+ ValidationError ,
13+ )
1114from web3 .providers .base import (
1215 JSONBaseProvider ,
1316)
1417
18+ RESTRICTED_WEBSOCKET_KWARGS = {'uri' , 'loop' }
19+
1520
1621def _start_event_loop (loop ):
1722 asyncio .set_event_loop (loop )
@@ -32,15 +37,17 @@ def get_default_endpoint():
3237
3338class PersistentWebSocket :
3439
35- def __init__ (self , endpoint_uri , loop , ** kwargs ):
40+ def __init__ (self , endpoint_uri , loop , websocket_kwargs ):
3641 self .ws = None
3742 self .endpoint_uri = endpoint_uri
3843 self .loop = loop
39- self .kwargs = kwargs
44+ self .websocket_kwargs = websocket_kwargs
4045
4146 async def __aenter__ (self ):
4247 if self .ws is None :
43- self .ws = await websockets .connect (uri = self .endpoint_uri , loop = self .loop , ** self .kwargs )
48+ self .ws = await websockets .connect (
49+ uri = self .endpoint_uri , loop = self .loop , ** self .websocket_kwargs
50+ )
4451 return self .ws
4552
4653 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
@@ -56,13 +63,26 @@ class WebsocketProvider(JSONBaseProvider):
5663 logger = logging .getLogger ("web3.providers.WebsocketProvider" )
5764 _loop = None
5865
59- def __init__ (self , endpoint_uri = None , ** kwargs ):
66+ def __init__ (self , endpoint_uri = None , websocket_kwargs = None ):
6067 self .endpoint_uri = endpoint_uri
6168 if self .endpoint_uri is None :
6269 self .endpoint_uri = get_default_endpoint ()
6370 if WebsocketProvider ._loop is None :
6471 WebsocketProvider ._loop = _get_threaded_loop ()
65- self .conn = PersistentWebSocket (self .endpoint_uri , WebsocketProvider ._loop , ** kwargs )
72+ if websocket_kwargs is None :
73+ websocket_kwargs = {}
74+ else :
75+ found_restricted_keys = set (websocket_kwargs .keys ()).intersection (
76+ RESTRICTED_WEBSOCKET_KWARGS
77+ )
78+ if found_restricted_keys :
79+ raise ValidationError (
80+ '{0} are not allowed in websocket_kwargs, '
81+ 'found: {1}' .format (RESTRICTED_WEBSOCKET_KWARGS , found_restricted_keys )
82+ )
83+ self .conn = PersistentWebSocket (
84+ self .endpoint_uri , WebsocketProvider ._loop , websocket_kwargs
85+ )
6686 super ().__init__ ()
6787
6888 def __str__ (self ):
0 commit comments