Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT TO V4] Add timeout for WebsocketProvider #1119

Merged
merged 4 commits into from
Oct 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@
identity,
)

from .utils import (
get_open_port,
)


@pytest.fixture(scope="module", params=[lambda x: to_bytes(hexstr=x), identity])
def address_conversion_func(request):
return request.param


@pytest.fixture()
def open_port():
return get_open_port()
49 changes: 49 additions & 0 deletions tests/core/providers/test_websocket_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
import asyncio
from concurrent.futures import (
TimeoutError,
)
import pytest
from threading import (
Thread,
)

import websockets

from tests.utils import (
wait_for_ws,
)
from web3 import Web3
from web3.exceptions import (
ValidationError,
)
Expand All @@ -8,6 +21,42 @@
)


@pytest.yield_fixture
def start_websocket_server(open_port):
event_loop = asyncio.new_event_loop()

def run_server():
async def empty_server(websocket, path):
data = await websocket.recv()
await asyncio.sleep(0.02)
await websocket.send(data)
server = websockets.serve(empty_server, '127.0.0.1', open_port, loop=event_loop)
event_loop.run_until_complete(server)
event_loop.run_forever()

thd = Thread(target=run_server)
thd.start()
try:
yield
finally:
event_loop.call_soon_threadsafe(event_loop.stop)


@pytest.fixture()
def w3(open_port, start_websocket_server):
# need new event loop as the one used by server is already running
event_loop = asyncio.new_event_loop()
endpoint_uri = 'ws://127.0.0.1:{}'.format(open_port)
event_loop.run_until_complete(wait_for_ws(endpoint_uri, event_loop))
provider = WebsocketProvider(endpoint_uri, websocket_timeout=0.01)
return Web3(provider)


def test_websocket_provider_timeout(w3):
with pytest.raises(TimeoutError):
w3.eth.accounts


def test_restricted_websocket_kwargs():
invalid_kwargs = {'uri': 'ws://127.0.0.1:8546'}
re_exc_message = r'.*found: {0}*'.format(set(invalid_kwargs.keys()))
Expand Down
11 changes: 3 additions & 8 deletions tests/generate_go_ethereum_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
to_wei,
)

from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.module_testing.emitter_contract import (
EMITTER_ABI,
Expand Down Expand Up @@ -100,14 +103,6 @@ def tempdir():
shutil.rmtree(dir_path)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


def get_geth_binary():
from geth.install import (
get_executable_path,
Expand Down
8 changes: 0 additions & 8 deletions tests/integration/generate_fixtures/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ def tempdir():
shutil.rmtree(dir_path)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


def get_geth_binary():
from geth.install import (
get_executable_path,
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/generate_fixtures/go_ethereum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
)

import common
from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.module_testing.emitter_contract import (
EMITTER_ABI,
Expand Down Expand Up @@ -42,7 +45,7 @@ def generate_go_ethereum_fixture(destination_dir):
geth_ipc_path_dir = stack.enter_context(common.tempdir())
geth_ipc_path = os.path.join(geth_ipc_path_dir, 'geth.ipc')

geth_port = common.get_open_port()
geth_port = get_open_port()
geth_binary = common.get_geth_binary()

geth_proc = stack.enter_context(common.get_geth_process( # noqa: F841
Expand Down
7 changes: 5 additions & 2 deletions tests/integration/generate_fixtures/parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

import common
import go_ethereum
from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.toolz import (
merge,
Expand Down Expand Up @@ -176,7 +179,7 @@ def generate_parity_fixture(destination_dir):

geth_datadir = stack.enter_context(common.tempdir())

geth_port = common.get_open_port()
geth_port = get_open_port()

geth_ipc_path_dir = stack.enter_context(common.tempdir())
geth_ipc_path = os.path.join(geth_ipc_path_dir, 'geth.ipc')
Expand Down Expand Up @@ -221,7 +224,7 @@ def generate_parity_fixture(destination_dir):
parity_ipc_path_dir = stack.enter_context(common.tempdir())
parity_ipc_path = os.path.join(parity_ipc_path_dir, 'jsonrpc.ipc')

parity_port = common.get_open_port()
parity_port = get_open_port()
parity_binary = get_parity_binary()

parity_proc = stack.enter_context(get_parity_process( # noqa: F841
Expand Down
9 changes: 0 additions & 9 deletions tests/integration/go_ethereum/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import socket

from web3.utils.module_testing import (
EthModuleTest,
Expand All @@ -10,14 +9,6 @@
)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


class GoEthereumTest(Web3ModuleTest):
def _check_web3_clientVersion(self, client_version):
assert client_version.startswith('Geth/')
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

from tests.utils import (
get_open_port,
)
from web3 import Web3

from .common import (
Expand All @@ -8,7 +11,6 @@
GoEthereumPersonalModuleTest,
GoEthereumTest,
GoEthereumVersionModuleTest,
get_open_port,
)
from .utils import (
wait_for_http,
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/go_ethereum/test_goethereum_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import pytest
import tempfile

from tests.utils import (
get_open_port,
)
from web3 import Web3

from .common import (
Expand All @@ -12,7 +15,6 @@
GoEthereumVersionModuleTest,
)
from .utils import (
get_open_port,
wait_for_socket,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/go_ethereum/test_goethereum_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from tests.integration.common import (
MiscWebsocketTest,
)
from tests.integration.utils import (
from tests.utils import (
get_open_port,
wait_for_ws,
)
from web3 import Web3
Expand All @@ -14,7 +15,6 @@
GoEthereumPersonalModuleTest,
GoEthereumTest,
GoEthereumVersionModuleTest,
get_open_port,
)


Expand Down
8 changes: 0 additions & 8 deletions tests/integration/go_ethereum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,3 @@ def kill_proc_gracefully(proc):
if proc.poll() is None:
proc.kill()
wait_for_popen(proc, 2)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)
9 changes: 0 additions & 9 deletions tests/integration/parity/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import socket

from flaky import (
flaky,
Expand All @@ -16,14 +15,6 @@
MAX_FLAKY_RUNS = 3


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


class ParityWeb3ModuleTest(Web3ModuleTest):
def _check_web3_clientVersion(self, client_version):
assert client_version.startswith('Parity/')
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/parity/test_parity_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from tests.integration.parity.utils import (
wait_for_http,
)
from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.module_testing import (
NetModuleTest,
Expand All @@ -15,7 +18,6 @@
ParityPersonalModuleTest,
ParityTraceModuleTest,
ParityWeb3ModuleTest,
get_open_port,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/integration/parity/test_parity_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from tests.integration.common import (
MiscWebsocketTest,
)
from tests.integration.utils import (
from tests.utils import (
get_open_port,
wait_for_ws,
)
from web3 import Web3
Expand All @@ -18,7 +19,6 @@
ParityPersonalModuleTest,
ParityTraceModuleTest,
ParityWeb3ModuleTest,
get_open_port,
)


Expand Down
8 changes: 0 additions & 8 deletions tests/integration/parity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,3 @@ def kill_proc_gracefully(proc):
if proc.poll() is None:
proc.kill()
wait_for_popen(proc, 2)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)
9 changes: 9 additions & 0 deletions tests/integration/utils.py → tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import asyncio
import socket
import time

import websockets


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


async def wait_for_ws(endpoint_uri, event_loop, timeout=60):
start = time.time()
while time.time() < start + timeout:
Expand Down
21 changes: 18 additions & 3 deletions web3/providers/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

RESTRICTED_WEBSOCKET_KWARGS = {'uri', 'loop'}
DEFAULT_WEBSOCKET_TIMEOUT = 10


def _start_event_loop(loop):
Expand Down Expand Up @@ -63,8 +64,14 @@ class WebsocketProvider(JSONBaseProvider):
logger = logging.getLogger("web3.providers.WebsocketProvider")
_loop = None

def __init__(self, endpoint_uri=None, websocket_kwargs=None):
def __init__(
self,
endpoint_uri=None,
websocket_kwargs=None,
websocket_timeout=DEFAULT_WEBSOCKET_TIMEOUT
):
self.endpoint_uri = endpoint_uri
self.websocket_timeout = websocket_timeout
if self.endpoint_uri is None:
self.endpoint_uri = get_default_endpoint()
if WebsocketProvider._loop is None:
Expand All @@ -90,8 +97,16 @@ def __str__(self):

async def coro_make_request(self, request_data):
async with self.conn as conn:
await conn.send(request_data)
return json.loads(await conn.recv())
await asyncio.wait_for(
conn.send(request_data),
timeout=self.websocket_timeout
)
return json.loads(
await asyncio.wait_for(
conn.recv(),
timeout=self.websocket_timeout
)
)

def make_request(self, method, params):
self.logger.debug("Making request WebSocket. URI: %s, "
Expand Down