Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: basic websocket support #65

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pook/interceptors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from .websocket import WebSocketInterceptor
from .urllib3 import Urllib3Interceptor
from .http import HTTPClientInterceptor
from .base import BaseInterceptor
Expand All @@ -7,6 +8,7 @@
__all__ = (
'interceptors', 'add', 'get',
'BaseInterceptor',
'WebSocketInterceptor',
'Urllib3Interceptor',
'HTTPClientInterceptor',
'AIOHTTPInterceptor',
Expand All @@ -15,7 +17,8 @@
# Store built-in interceptors in pook.
interceptors = [
Urllib3Interceptor,
HTTPClientInterceptor
HTTPClientInterceptor,
WebSocketInterceptor
]

# Import aiohttp in modern Python runtimes
Expand Down
110 changes: 110 additions & 0 deletions pook/interceptors/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import tempfile
from ..request import Request
from .base import BaseInterceptor

# Support Python 2/3
try:
import mock
except Exception:
from unittest import mock


class MockFrame(object):
def __init__(self, opcode, data, fin=1):
self.opcode = opcode
self.data = data
self.fin = fin


class MockHandshakeResponse(object):

def __init__(self, status, headers, subprotocol):
self.status = status
self.headers = headers
self.subprotocol = subprotocol


class WebSocketInterceptor(BaseInterceptor):

def _handshake(self, sock, hostname, port, resource, **options):
return MockHandshakeResponse(200, {}, "")

def _connect(self, url, options, proxy, socket):
req = Request()
req.headers = {} # TODO
req.url = url

# TODO does this work multithreaded?!?
self._mock = self.engine.match(req)
Copy link
Owner

@h2non h2non Jan 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will not work. Mutex synchronization is required to make it safe. However, I feel that's another feature that has to be implemented in a separate PR.

if not self._mock:
# we cannot forward, as we have mocked away the connection
raise ValueError(
"Request to '%s' could not be matched or forwarded" % url
)

# make the body always a list to simplify our lives
body = self._mock._response._body
if not isinstance(body, list):
self._mock._response._body = [body]

# We have to return a valid file descriptor, not just a random integer.
# Source: https://docs.python.org/3/library/select.html#select.select
sock = tempfile.TemporaryFile()
addr = ("hostname", "port", "resource")

return sock, addr

def _send(self, data):
# mock as if all data has been sent
return len(data)

def _recv_frame(self):
# alias
body = self._mock._response._body

idx = getattr(self, "_data_index", 0)
if len(body) <= idx:
# close frame
return MockFrame(0x8, None)

# data frame
self._data_index = idx + 1
return MockFrame(0x1, body[idx].encode("utf-8"), 1)

def _patch(self, path, handler):
try:
# Create a new patcher for Urllib3 urlopen function
# used as entry point for all the HTTP communications
patcher = mock.patch(path, handler)
# Retrieve original patched function that we might need for real
# networking
# request = patcher.get_original()[0]
# Start patching function calls
patcher.start()
except Exception:
# Exceptions may accur due to missing package
# Ignore all the exceptions for now
pass
else:
self.patchers.append(patcher)

def activate(self):
"""
Activates the traffic interceptor.
This method must be implemented by any interceptor.
"""
patches = [
('websocket._core.connect', self._connect),
('websocket._core.handshake', self._handshake),
('websocket.WebSocket._send', self._send),
('websocket.WebSocket.recv_frame', self._recv_frame),
]

[self._patch(path, handler) for path, handler in patches]

def disable(self):
"""
Disables the traffic interceptor.
This method must be implemented by any interceptor.
"""
[patch.stop() for patch in self.patchers]
4 changes: 2 additions & 2 deletions pook/matchers/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
else: # Python 3
from urllib.parse import urlparse

# URI protocol test regular expression
protoregex = re.compile('^http[s]?://', re.IGNORECASE)
# URI protocol test regular expression (without group capture)
protoregex = re.compile('^(?:ws|http)[s]?://', re.IGNORECASE)


class URLMatcher(BaseMatcher):
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ urllib3>=1.24.2
bumpversion~=0.5.3
aiohttp~=1.1.5 ; python_version >= '3.4.2'
mocket~=1.6.0
websocket-client~=0.57.0
13 changes: 13 additions & 0 deletions tests/unit/interceptors/websocket_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from websocket import WebSocket
import pook


@pook.on
def test_websocket():
(pook.get('ws://some-non-existing.org')
.reply(204)
.body('test'))

socket = WebSocket()
socket.connect('ws://some-non-existing.org/', header=['x-custom: header'])
socket.send('test')
3 changes: 3 additions & 0 deletions tests/unit/matchers/url_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def test_url_matcher_urlparse():
('http://foo.com/foo/bar', 'http://foo.com/foo/bar', True),
('http://foo.com/foo/bar/baz', 'http://foo.com/foo/bar/baz', True),
('http://foo.com/foo?x=y&z=w', 'http://foo.com/foo?x=y&z=w', True),
('ws://foo.com', 'ws://foo.com', True),
('wss://foo.com', 'wss://foo.com', True),

# Invalid cases
('http://foo.com', 'http://bar.com', False),
Expand All @@ -41,6 +43,7 @@ def test_url_matcher_urlparse():
('http://foo.com/foo/bar', 'http://foo.com/bar/foo', False),
('http://foo.com/foo/bar/baz', 'http://foo.com/baz/bar/foo', False),
('http://foo.com/foo?x=y&z=w', 'http://foo.com/foo?x=x&y=y', False),
('ws://foo.com', 'wss://foo.com', False),
))


Expand Down