Skip to content
This repository was archived by the owner on Mar 13, 2022. It is now read-only.

binary websocket streams #52

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
def stream(func, *args, **kwargs):
"""Stream given API call using websocket"""

binary = kwargs.pop('binary', False)

def _intercept_request_call(*args, **kwargs):
# old generated code's api client has config. new ones has
# configuration
try:
config = func.__self__.api_client.configuration
except AttributeError:
config = func.__self__.api_client.config
kwargs['binary'] = binary

return ws_client.websocket_call(config, *args, **kwargs)

Expand Down
32 changes: 21 additions & 11 deletions stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


class WSClient:
def __init__(self, configuration, url, headers):
def __init__(self, configuration, url, headers, binary=False):
"""A websocket client with support for channels.

Exec command uses different channels for different streams. for
Expand All @@ -41,7 +41,8 @@ def __init__(self, configuration, url, headers):
header = []
self._connected = False
self._channels = {}
self._all = ""
self._all = b"" if binary else ''
self.binary = binary

# We just need to pass the Authorization, ignore all the other
# http headers we get from the generated code
Expand Down Expand Up @@ -98,8 +99,8 @@ def readline_channel(self, channel, timeout=None):
while self.is_open() and time.time() - start < timeout:
if channel in self._channels:
data = self._channels[channel]
if "\n" in data:
index = data.find("\n")
if b"\n" if self.binary else '\n' in data:
index = data.find(b"\n" if self.binary else '\n')
ret = data[:index]
data = data[index+1:]
if data:
Expand All @@ -111,7 +112,9 @@ def readline_channel(self, channel, timeout=None):

def write_channel(self, channel, data):
"""Write data to a channel."""
self.sock.send(chr(channel) + data)
if type(data) != bytes:
data = bytes(data, 'utf-8')
self.sock.send(bytearray((channel, )) + data)

def peek_stdout(self, timeout=0):
"""Same as peek_channel with channel=1."""
Expand Down Expand Up @@ -147,7 +150,7 @@ def read_all(self):
channels mapped for each input.
"""
out = self._all
self._all = ""
self._all = b"" if self.binary else ''
self._channels = {}
return out

Expand Down Expand Up @@ -175,12 +178,15 @@ def update(self, timeout=0):
return
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
data = frame.data
if six.PY3:
data = data.decode("utf-8")
if len(data) > 1:
channel = ord(data[0])
if six.PY2:
channel = ord(data[0])
else:
channel = int(data[0])
data = data[1:]
if data:
if not self.binary and six.PY3:
data = data.decode('utf-8')
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
# keeping all messages in the order they received for
# non-blocking call.
Expand Down Expand Up @@ -232,6 +238,7 @@ def websocket_call(configuration, *args, **kwargs):
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
headers = kwargs.get("headers")
binary = kwargs.pop("binary", False)

# Expand command parameter list to indivitual command params
query_params = []
Expand All @@ -246,10 +253,13 @@ def websocket_call(configuration, *args, **kwargs):
url += '?' + urlencode(query_params)

try:
client = WSClient(configuration, get_websocket_url(url), headers)
client = WSClient(configuration, get_websocket_url(url), headers, binary)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
return WSResponse('%s' % ''.join(client.read_all()))
if binary:
return WSResponse(b''.join(client.read_all()))
else:
return WSResponse('%s' % ''.join(client.read_all()))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))