Skip to content

Commit b675e5c

Browse files
committed
added docs and tests for websocket_kwargs
1 parent 7f1e6c3 commit b675e5c

File tree

6 files changed

+87
-6
lines changed

6 files changed

+87
-6
lines changed

docs/providers.rst

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,15 +224,30 @@ IPCProvider
224224
WebsocketProvider
225225
~~~~~~~~~~~~~~~~~
226226
227-
.. py:class:: web3.providers.websocket.WebsocketProvider(endpoint_uri)
227+
.. py:class:: web3.providers.websocket.WebsocketProvider(endpoint_uri[, websocket_kwargs])
228228
229229
This provider handles interactions with an WS or WSS based JSON-RPC server.
230230
231+
* ``endpoint_uri`` should be the full URI to the RPC endpoint such as
232+
``'ws://localhost:8546'``.
233+
* ``websocket_kwargs`` this should be a dictionary of keyword arguments which
234+
will be passed onto the ws/wss websocket connection.
235+
231236
.. code-block:: python
232237
233238
>>> from web3 import Web3
234239
>>> web3 = Web3(Web3.WebsocketProvider("ws://127.0.0.1:8546")
235240
241+
Under the hood, the ``WebsocketProvider`` uses the python websockets library for
242+
making requests. If you would like to modify how requests are made, you can
243+
use the ``websocket_kwargs`` to do so. A common use case for this is increasing
244+
the timeout for each request.
245+
246+
247+
.. code-block:: python
248+
249+
>>> from web3 import Web3
250+
>>> web3 = Web3(Web3.WebsocketProvider("http://127.0.0.1:8546", websocket_kwargs={'timeout': 60}))
236251
237252
.. py:currentmodule:: web3.providers.eth_tester
238253
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from web3.exceptions import (
4+
ValidationError,
5+
)
6+
from web3.providers.websocket import (
7+
WebsocketProvider,
8+
)
9+
10+
11+
def test_restricted_websocket_kwargs():
12+
invalid_kwargs = {'uri': 'ws://127.0.0.1:8546'}
13+
re_exc_message = r'.*found: {0}*'.format(set(invalid_kwargs.keys()))
14+
with pytest.raises(ValidationError, match=re_exc_message):
15+
WebsocketProvider(websocket_kwargs=invalid_kwargs)

tests/integration/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
3+
from websockets.exceptions import (
4+
ConnectionClosed,
5+
)
6+
7+
from web3 import Web3
8+
9+
10+
class MiscWebsocketTests:
11+
12+
def test_websocket_max_size_error(self, web3, endpoint_uri):
13+
w3 = Web3(Web3.WebsocketProvider(
14+
endpoint_uri=endpoint_uri, websocket_kwargs={'max_size': 1})
15+
)
16+
with pytest.raises(ConnectionClosed):
17+
w3.eth.getBlock(0)

tests/integration/go_ethereum/test_goethereum_ws.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import pytest
22

3+
from tests.integration.common import (
4+
MiscWebsocketTests,
5+
)
36
from tests.integration.utils import (
47
wait_for_ws,
58
)
@@ -65,3 +68,7 @@ class TestGoEthereumNetModuleTest(GoEthereumNetModuleTest):
6568

6669
class TestGoEthereumPersonalModuleTest(GoEthereumPersonalModuleTest):
6770
pass
71+
72+
73+
class TestMiscWebsocketTests(MiscWebsocketTests):
74+
pass

tests/integration/parity/test_parity_ws.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import os
22
import pytest
33

4+
from tests.integration.common import (
5+
MiscWebsocketTests,
6+
)
47
from tests.integration.utils import (
58
wait_for_ws,
69
)
@@ -96,3 +99,7 @@ class TestParityPersonalModuleTest(ParityPersonalModuleTest):
9699

97100
class TestParityTraceModuleTest(ParityTraceModuleTest):
98101
pass
102+
103+
104+
class TestMiscWebsocketTests(MiscWebsocketTests):
105+
pass

web3/providers/websocket.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@
88

99
import websockets
1010

11+
from web3.exceptions import (
12+
ValidationError,
13+
)
1114
from web3.providers.base import (
1215
JSONBaseProvider,
1316
)
1417

18+
RESTRICTED_WEBSOCKET_KWARGS = {'uri', 'loop'}
19+
1520

1621
def _start_event_loop(loop):
1722
asyncio.set_event_loop(loop)
@@ -32,15 +37,17 @@ def get_default_endpoint():
3237

3338
class PersistentWebSocket:
3439

35-
def __init__(self, endpoint_uri, loop, **kwargs):
40+
def __init__(self, endpoint_uri, loop, websocket_kwargs):
3641
self.ws = None
3742
self.endpoint_uri = endpoint_uri
3843
self.loop = loop
39-
self.kwargs = kwargs
44+
self.websocket_kwargs = websocket_kwargs
4045

4146
async def __aenter__(self):
4247
if self.ws is None:
43-
self.ws = await websockets.connect(uri=self.endpoint_uri, loop=self.loop, **self.kwargs)
48+
self.ws = await websockets.connect(
49+
uri=self.endpoint_uri, loop=self.loop, **self.websocket_kwargs
50+
)
4451
return self.ws
4552

4653
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -56,13 +63,26 @@ class WebsocketProvider(JSONBaseProvider):
5663
logger = logging.getLogger("web3.providers.WebsocketProvider")
5764
_loop = None
5865

59-
def __init__(self, endpoint_uri=None, **kwargs):
66+
def __init__(self, endpoint_uri=None, websocket_kwargs=None):
6067
self.endpoint_uri = endpoint_uri
6168
if self.endpoint_uri is None:
6269
self.endpoint_uri = get_default_endpoint()
6370
if WebsocketProvider._loop is None:
6471
WebsocketProvider._loop = _get_threaded_loop()
65-
self.conn = PersistentWebSocket(self.endpoint_uri, WebsocketProvider._loop, **kwargs)
72+
if websocket_kwargs is None:
73+
websocket_kwargs = {}
74+
else:
75+
found_restricted_keys = set(websocket_kwargs.keys()).intersection(
76+
RESTRICTED_WEBSOCKET_KWARGS
77+
)
78+
if found_restricted_keys:
79+
raise ValidationError(
80+
'{0} are not allowed in websocket_kwargs, '
81+
'found: {1}'.format(RESTRICTED_WEBSOCKET_KWARGS, found_restricted_keys)
82+
)
83+
self.conn = PersistentWebSocket(
84+
self.endpoint_uri, WebsocketProvider._loop, websocket_kwargs
85+
)
6686
super().__init__()
6787

6888
def __str__(self):

0 commit comments

Comments
 (0)