Skip to content
This repository was archived by the owner on Mar 13, 2022. It is now read-only.

Commit fd62214

Browse files
committed
Refactor stream package to enable common method helpers for other streaming api classes.
1 parent 54d188f commit fd62214

File tree

2 files changed

+70
-66
lines changed

2 files changed

+70
-66
lines changed

stream/stream.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import types
16+
1517
from . import ws_client
1618

1719

1820
def stream(func, *args, **kwargs):
1921
"""Stream given API call using websocket.
2022
Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()"""
2123

22-
def _intercept_request_call(*args, **kwargs):
23-
# old generated code's api client has config. new ones has
24-
# configuration
25-
try:
26-
config = func.__self__.api_client.configuration
27-
except AttributeError:
28-
config = func.__self__.api_client.config
29-
30-
return ws_client.websocket_call(config, *args, **kwargs)
31-
32-
prev_request = func.__self__.api_client.request
24+
api_client = func.__self__.api_client
25+
prev_request = api_client.request
3326
try:
34-
func.__self__.api_client.request = _intercept_request_call
27+
api_client.request = types.MethodType(ws_client.websocket_call, api_client)
3528
return func(*args, **kwargs)
3629
finally:
37-
func.__self__.api_client.request = prev_request
30+
api_client.request = prev_request

stream/ws_client.py

+64-53
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import six
2424
import yaml
2525

26-
from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse
26+
from six.moves.urllib.parse import urlencode, urlparse, urlunparse
2727
from six import StringIO
2828

2929
from websocket import WebSocket, ABNF, enableTrace
@@ -51,47 +51,13 @@ def __init__(self, configuration, url, headers, capture_all):
5151
like port forwarding can forward different pods' streams to different
5252
channels.
5353
"""
54-
enableTrace(False)
55-
header = []
5654
self._connected = False
5755
self._channels = {}
5856
if capture_all:
5957
self._all = StringIO()
6058
else:
6159
self._all = _IgnoredIO()
62-
63-
# We just need to pass the Authorization, ignore all the other
64-
# http headers we get from the generated code
65-
if headers and 'authorization' in headers:
66-
header.append("authorization: %s" % headers['authorization'])
67-
68-
if headers and 'sec-websocket-protocol' in headers:
69-
header.append("sec-websocket-protocol: %s" %
70-
headers['sec-websocket-protocol'])
71-
else:
72-
header.append("sec-websocket-protocol: v4.channel.k8s.io")
73-
74-
if url.startswith('wss://') and configuration.verify_ssl:
75-
ssl_opts = {
76-
'cert_reqs': ssl.CERT_REQUIRED,
77-
'ca_certs': configuration.ssl_ca_cert or certifi.where(),
78-
}
79-
if configuration.assert_hostname is not None:
80-
ssl_opts['check_hostname'] = configuration.assert_hostname
81-
else:
82-
ssl_opts = {'cert_reqs': ssl.CERT_NONE}
83-
84-
if configuration.cert_file:
85-
ssl_opts['certfile'] = configuration.cert_file
86-
if configuration.key_file:
87-
ssl_opts['keyfile'] = configuration.key_file
88-
89-
self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
90-
if configuration.proxy:
91-
proxy_url = urlparse(configuration.proxy)
92-
self.sock.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port)
93-
else:
94-
self.sock.connect(url, header=header)
60+
self.sock = create_websocket(configuration, url, headers)
9561
self._connected = True
9662

9763
def peek_channel(self, channel, timeout=0):
@@ -259,41 +225,86 @@ def close(self, **kwargs):
259225
WSResponse = collections.namedtuple('WSResponse', ['data'])
260226

261227

262-
def get_websocket_url(url):
228+
def get_websocket_url(url, query_params=None):
263229
parsed_url = urlparse(url)
264230
parts = list(parsed_url)
265231
if parsed_url.scheme == 'http':
266232
parts[0] = 'ws'
267233
elif parsed_url.scheme == 'https':
268234
parts[0] = 'wss'
235+
if query_params:
236+
query = []
237+
for key, value in query_params:
238+
if key == 'command' and isinstance(value, list):
239+
for command in value:
240+
query.append((key, command))
241+
else:
242+
query.append((key, value))
243+
if query:
244+
parts[4] = urlencode(query)
269245
return urlunparse(parts)
270246

271247

272-
def websocket_call(configuration, *args, **kwargs):
248+
def create_websocket(configuration, url, headers=None):
249+
enableTrace(False)
250+
251+
# We just need to pass the Authorization, ignore all the other
252+
# http headers we get from the generated code
253+
header = []
254+
if headers and 'authorization' in headers:
255+
header.append("authorization: %s" % headers['authorization'])
256+
if headers and 'sec-websocket-protocol' in headers:
257+
header.append("sec-websocket-protocol: %s" %
258+
headers['sec-websocket-protocol'])
259+
else:
260+
header.append("sec-websocket-protocol: v4.channel.k8s.io")
261+
262+
if url.startswith('wss://') and configuration.verify_ssl:
263+
ssl_opts = {
264+
'cert_reqs': ssl.CERT_REQUIRED,
265+
'ca_certs': configuration.ssl_ca_cert or certifi.where(),
266+
}
267+
if configuration.assert_hostname is not None:
268+
ssl_opts['check_hostname'] = configuration.assert_hostname
269+
else:
270+
ssl_opts = {'cert_reqs': ssl.CERT_NONE}
271+
272+
if configuration.cert_file:
273+
ssl_opts['certfile'] = configuration.cert_file
274+
if configuration.key_file:
275+
ssl_opts['keyfile'] = configuration.key_file
276+
277+
websocket = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
278+
if configuration.proxy:
279+
proxy_url = urlparse(configuration.proxy)
280+
websocket.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port)
281+
else:
282+
websocket.connect(url, header=header)
283+
return websocket
284+
285+
286+
def _configuration(api_client):
287+
# old generated code's api client has config. new ones has
288+
# configuration
289+
try:
290+
return api_client.configuration
291+
except AttributeError:
292+
return api_client.config
293+
294+
295+
def websocket_call(api_client, _method, url, **kwargs):
273296
"""An internal function to be called in api-client when a websocket
274297
connection is required. args and kwargs are the parameters of
275298
apiClient.request method."""
276299

277-
url = args[1]
300+
url = get_websocket_url(url, kwargs.get("query_params"))
301+
headers = kwargs.get("headers")
278302
_request_timeout = kwargs.get("_request_timeout", 60)
279303
_preload_content = kwargs.get("_preload_content", True)
280304
capture_all = kwargs.get("capture_all", True)
281-
headers = kwargs.get("headers")
282-
283-
# Expand command parameter list to indivitual command params
284-
query_params = []
285-
for key, value in kwargs.get("query_params", {}):
286-
if key == 'command' and isinstance(value, list):
287-
for command in value:
288-
query_params.append((key, command))
289-
else:
290-
query_params.append((key, value))
291-
292-
if query_params:
293-
url += '?' + urlencode(query_params)
294305

295306
try:
296-
client = WSClient(configuration, get_websocket_url(url), headers, capture_all)
307+
client = WSClient(_configuration(api_client), url, headers, capture_all)
297308
if not _preload_content:
298309
return client
299310
client.run_forever(timeout=_request_timeout)

0 commit comments

Comments
 (0)