diff --git a/stream/stream.py b/stream/stream.py index 6d5f05f8..9bb59017 100644 --- a/stream/stream.py +++ b/stream/stream.py @@ -12,26 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import ws_client - +import functools -def stream(func, *args, **kwargs): - """Stream given API call using websocket. - Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()""" - - def _intercept_request_call(*args, **kwargs): - # old generated code's api client has config. new ones has - # configuration - try: - config = func.__self__.api_client.configuration - except AttributeError: - config = func.__self__.api_client.config +from . import ws_client - return ws_client.websocket_call(config, *args, **kwargs) - prev_request = func.__self__.api_client.request +def _websocket_reqeust(websocket_request, api_method, *args, **kwargs): + """Override the ApiClient.request method with an alternative websocket based + method and call the supplied Kubernetes API method with that in place.""" + api_client = api_method.__self__.api_client + # old generated code's api client has config. new ones has configuration try: - func.__self__.api_client.request = _intercept_request_call - return func(*args, **kwargs) + configuration = api_client.configuration + except AttributeError: + configuration = api_client.config + prev_request = api_client.request + try: + api_client.request = functools.partial(websocket_request, configuration) + return api_method(*args, **kwargs) finally: - func.__self__.api_client.request = prev_request + api_client.request = prev_request + + +stream = functools.partial(_websocket_reqeust, ws_client.websocket_call) diff --git a/stream/ws_client.py b/stream/ws_client.py index 2b599381..fa7f393e 100644 --- a/stream/ws_client.py +++ b/stream/ws_client.py @@ -23,7 +23,7 @@ import six import yaml -from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse +from six.moves.urllib.parse import urlencode, urlparse, urlunparse from six import StringIO from websocket import WebSocket, ABNF, enableTrace @@ -51,47 +51,13 @@ def __init__(self, configuration, url, headers, capture_all): like port forwarding can forward different pods' streams to different channels. """ - enableTrace(False) - header = [] self._connected = False self._channels = {} if capture_all: self._all = StringIO() else: self._all = _IgnoredIO() - - # We just need to pass the Authorization, ignore all the other - # http headers we get from the generated code - if headers and 'authorization' in headers: - header.append("authorization: %s" % headers['authorization']) - - if headers and 'sec-websocket-protocol' in headers: - header.append("sec-websocket-protocol: %s" % - headers['sec-websocket-protocol']) - else: - header.append("sec-websocket-protocol: v4.channel.k8s.io") - - if url.startswith('wss://') and configuration.verify_ssl: - ssl_opts = { - 'cert_reqs': ssl.CERT_REQUIRED, - 'ca_certs': configuration.ssl_ca_cert or certifi.where(), - } - if configuration.assert_hostname is not None: - ssl_opts['check_hostname'] = configuration.assert_hostname - else: - ssl_opts = {'cert_reqs': ssl.CERT_NONE} - - if configuration.cert_file: - ssl_opts['certfile'] = configuration.cert_file - if configuration.key_file: - ssl_opts['keyfile'] = configuration.key_file - - self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False) - if configuration.proxy: - proxy_url = urlparse(configuration.proxy) - self.sock.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port) - else: - self.sock.connect(url, header=header) + self.sock = create_websocket(configuration, url, headers) self._connected = True def peek_channel(self, channel, timeout=0): @@ -259,41 +225,77 @@ def close(self, **kwargs): WSResponse = collections.namedtuple('WSResponse', ['data']) -def get_websocket_url(url): +def get_websocket_url(url, query_params=None): parsed_url = urlparse(url) parts = list(parsed_url) if parsed_url.scheme == 'http': parts[0] = 'ws' elif parsed_url.scheme == 'https': parts[0] = 'wss' + if query_params: + query = [] + for key, value in query_params: + if key == 'command' and isinstance(value, list): + for command in value: + query.append((key, command)) + else: + query.append((key, value)) + if query: + parts[4] = urlencode(query) return urlunparse(parts) -def websocket_call(configuration, *args, **kwargs): +def create_websocket(configuration, url, headers=None): + enableTrace(False) + + # We just need to pass the Authorization, ignore all the other + # http headers we get from the generated code + header = [] + if headers and 'authorization' in headers: + header.append("authorization: %s" % headers['authorization']) + if headers and 'sec-websocket-protocol' in headers: + header.append("sec-websocket-protocol: %s" % + headers['sec-websocket-protocol']) + else: + header.append("sec-websocket-protocol: v4.channel.k8s.io") + + if url.startswith('wss://') and configuration.verify_ssl: + ssl_opts = { + 'cert_reqs': ssl.CERT_REQUIRED, + 'ca_certs': configuration.ssl_ca_cert or certifi.where(), + } + if configuration.assert_hostname is not None: + ssl_opts['check_hostname'] = configuration.assert_hostname + else: + ssl_opts = {'cert_reqs': ssl.CERT_NONE} + + if configuration.cert_file: + ssl_opts['certfile'] = configuration.cert_file + if configuration.key_file: + ssl_opts['keyfile'] = configuration.key_file + + websocket = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False) + if configuration.proxy: + proxy_url = urlparse(configuration.proxy) + websocket.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port) + else: + websocket.connect(url, header=header) + return websocket + + +def websocket_call(configuration, _method, url, **kwargs): """An internal function to be called in api-client when a websocket - connection is required. args and kwargs are the parameters of + connection is required. method, url, and kwargs are the parameters of apiClient.request method.""" - url = args[1] + url = get_websocket_url(url, kwargs.get("query_params")) + headers = kwargs.get("headers") _request_timeout = kwargs.get("_request_timeout", 60) _preload_content = kwargs.get("_preload_content", True) capture_all = kwargs.get("capture_all", True) - headers = kwargs.get("headers") - - # Expand command parameter list to indivitual command params - query_params = [] - for key, value in kwargs.get("query_params", {}): - if key == 'command' and isinstance(value, list): - for command in value: - query_params.append((key, command)) - else: - query_params.append((key, value)) - - if query_params: - url += '?' + urlencode(query_params) try: - client = WSClient(configuration, get_websocket_url(url), headers, capture_all) + client = WSClient(configuration, url, headers, capture_all) if not _preload_content: return client client.run_forever(timeout=_request_timeout)