diff --git a/stream/stream.py b/stream/stream.py index 0412fc33..611a8589 100644 --- a/stream/stream.py +++ b/stream/stream.py @@ -16,6 +16,8 @@ def stream(func, *args, **kwargs): """Stream given API call using websocket""" + binary = kwargs.pop('binary', False) + def _intercept_request_call(*args, **kwargs): # old generated code's api client has config. new ones has # configuration @@ -23,6 +25,7 @@ def _intercept_request_call(*args, **kwargs): config = func.__self__.api_client.configuration except AttributeError: config = func.__self__.api_client.config + kwargs['binary'] = binary return ws_client.websocket_call(config, *args, **kwargs) diff --git a/stream/ws_client.py b/stream/ws_client.py index 1cc56cdd..0b7baa3f 100644 --- a/stream/ws_client.py +++ b/stream/ws_client.py @@ -29,7 +29,7 @@ class WSClient: - def __init__(self, configuration, url, headers): + def __init__(self, configuration, url, headers, binary=False): """A websocket client with support for channels. Exec command uses different channels for different streams. for @@ -41,7 +41,8 @@ def __init__(self, configuration, url, headers): header = [] self._connected = False self._channels = {} - self._all = "" + self._all = b"" if binary else '' + self.binary = binary # We just need to pass the Authorization, ignore all the other # http headers we get from the generated code @@ -98,8 +99,8 @@ def readline_channel(self, channel, timeout=None): while self.is_open() and time.time() - start < timeout: if channel in self._channels: data = self._channels[channel] - if "\n" in data: - index = data.find("\n") + if b"\n" if self.binary else '\n' in data: + index = data.find(b"\n" if self.binary else '\n') ret = data[:index] data = data[index+1:] if data: @@ -111,7 +112,9 @@ def readline_channel(self, channel, timeout=None): def write_channel(self, channel, data): """Write data to a channel.""" - self.sock.send(chr(channel) + data) + if type(data) != bytes: + data = bytes(data, 'utf-8') + self.sock.send(bytearray((channel, )) + data) def peek_stdout(self, timeout=0): """Same as peek_channel with channel=1.""" @@ -147,7 +150,7 @@ def read_all(self): channels mapped for each input. """ out = self._all - self._all = "" + self._all = b"" if self.binary else '' self._channels = {} return out @@ -175,12 +178,15 @@ def update(self, timeout=0): return elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT: data = frame.data - if six.PY3: - data = data.decode("utf-8") if len(data) > 1: - channel = ord(data[0]) + if six.PY2: + channel = ord(data[0]) + else: + channel = int(data[0]) data = data[1:] if data: + if not self.binary and six.PY3: + data = data.decode('utf-8') if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]: # keeping all messages in the order they received for # non-blocking call. @@ -232,6 +238,7 @@ def websocket_call(configuration, *args, **kwargs): _request_timeout = kwargs.get("_request_timeout", 60) _preload_content = kwargs.get("_preload_content", True) headers = kwargs.get("headers") + binary = kwargs.pop("binary", False) # Expand command parameter list to indivitual command params query_params = [] @@ -246,10 +253,13 @@ def websocket_call(configuration, *args, **kwargs): url += '?' + urlencode(query_params) try: - client = WSClient(configuration, get_websocket_url(url), headers) + client = WSClient(configuration, get_websocket_url(url), headers, binary) if not _preload_content: return client client.run_forever(timeout=_request_timeout) - return WSResponse('%s' % ''.join(client.read_all())) + if binary: + return WSResponse(b''.join(client.read_all())) + else: + return WSResponse('%s' % ''.join(client.read_all())) except (Exception, KeyboardInterrupt, SystemExit) as e: raise ApiException(status=0, reason=str(e))