-
Notifications
You must be signed in to change notification settings - Fork 601
/
client.py
executable file
·386 lines (308 loc) · 11.4 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
from re import compile
import time
import sys
import requests
from .rest import RequestsNetworkWrapper, ServiceClient
from .rest_client import RESTServiceClient, Endpoint, AliasEndpoint
from .custom_query_object import CustomQueryObject
import os
import logging
logger = logging.getLogger(__name__)

# Endpoint names may consist only of word characters, spaces and hyphens.
_ENDPOINT_NAME_REGEX = r"^[\w -]+$"
_name_checker = compile(_ENDPOINT_NAME_REGEX)
def _check_endpoint_type(name):
    """Ensure *name* is a non-empty string.

    Raises
    ------
    TypeError
        If *name* is not a ``str``.
    ValueError
        If *name* is the empty string.
    """
    if not isinstance(name, str):
        raise TypeError("Endpoint name must be a string")
    # Only strings reach this point, so the emptiness test is well-defined.
    is_empty = name == ""
    if is_empty:
        raise ValueError("Endpoint name cannot be empty")
# Server URL pattern, compiled once at import time (like ``_name_checker``)
# rather than on every call. Accepts ``http://`` or ``https://``, a host made
# of letters, digits, underscores, periods and hyphens, then an optional
# ``:<port>``. It also tolerates an optional ``/`` before and after the port.
_hostname_checker = compile(r"^http(s)?://[a-zA-Z0-9_.-]+(/)?(:[0-9]+)?(/)?$")


def _check_hostname(name):
    """Validate that *name* is a server URL of the form
    ``http(s)://<hostname>[:<port>]``.

    Parameters
    ----------
    name : str
        The server URL to validate.

    Raises
    ------
    TypeError
        If *name* is not a string (raised by ``_check_endpoint_type``).
    ValueError
        If *name* is empty or does not match ``_hostname_checker``.
    """
    _check_endpoint_type(name)
    if not _hostname_checker.match(name):
        raise ValueError(
            f"endpoint name {name} should be in http(s)://<hostname>"
            "[:<port>] and hostname may consist only of: "
            "a-z, A-Z, 0-9, underscores, hyphens and periods."
        )
def _check_endpoint_name(name):
    """Validate an endpoint name.

    The name must be a non-empty string (enforced by ``_check_endpoint_type``)
    that matches ``_name_checker``, i.e. consists only of word characters,
    spaces and hyphens.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is empty or contains a character outside the allowed set.
    """
    _check_endpoint_type(name)
    if _name_checker.match(name) is None:
        message = (
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
            " underscore, hyphens and spaces."
        )
        raise ValueError(message)
class Client:
    """Client for a running server, built on ``RESTServiceClient``."""

    def __init__(self, endpoint, query_timeout=1000):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in the form ``http(s)://<hostname>[:<port>]``
            (validated by ``_check_hostname``).
        query_timeout : float, optional
            The timeout for query operations. ``None`` or a non-positive
            value is stored as ``0.0``.

        Raises
        ------
        TypeError, ValueError
            If *endpoint* is not a well-formed server URL.
        """
        _check_hostname(endpoint)
        self._endpoint = endpoint

        session = requests.session()
        # NOTE(review): TLS certificate verification is turned off and urllib3
        # warnings are silenced process-wide. This allows servers with
        # self-signed certificates, but leaves HTTPS traffic open to
        # man-in-the-middle attacks. Consider making it configurable.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)
        self._service = RESTServiceClient(service_client)

        # Goes through the ``query_timeout`` property, so ``_service`` must
        # already exist at this point.
        if query_timeout is not None and query_timeout > 0:
            self.query_timeout = query_timeout
        else:
            self.query_timeout = 0.0

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} object at {hex(id(self))}"
            f" connected to {self._endpoint!r}>"
        )

    def get_status(self):
        """
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        """
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries in seconds."""
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.
        *args : list of anything
            Ordered parameters to the endpoint.
        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : str, optional
            Forwarded to the service to filter endpoints by type. The name
            shadows the builtin but is kept for backward compatibility.

        Returns
        -------
        dict
            Maps endpoint name to its description. ``deploy`` reads
            ``.version`` as an attribute on these values, so they are
            presumably ``Endpoint`` objects rather than plain dicts (the
            example below shows their JSON form). NOTE(review): confirm this
            against ``RESTServiceClient.get_endpoints``.

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": null},
            "add": {
              "description": "",
              "docstring": "-- no docstring found in query function --",
              "creation_time": 1469505967,
              "version": 1,
              "dependencies": [],
              "last_modified_time": 1469505967,
              "type": "model",
              "target": null}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination (the ``"path"`` entry)."""
        return self._service.get_endpoint_upload_destination()["path"]

    def deploy(self, name, obj, description="", schema=None, override=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint. It may contain only word
            characters, spaces and hyphens.
        obj : function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.
        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API. Pass ``None`` to use the stripped
            docstring of *obj* instead.
        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.
        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        Raises
        ------
        RuntimeError
            If the endpoint exists and *override* is False, if the server
            reports ``LoadFailed``, or if deployment does not finish in time.

        See Also
        --------
        get_endpoints
        """
        existing = self.get_endpoints().get(name)
        if existing:
            if not override:
                raise RuntimeError(
                    f"An endpoint with that name ({name}) already"
                    ' exists. Use "override = True" to force update '
                    "an existing endpoint."
                )
            version = existing.version + 1
        else:
            version = 1

        endpoint = self._gen_endpoint(name, obj, description, version, schema)
        self._upload_endpoint(endpoint)

        # A brand-new endpoint is added; an existing one gets a new version.
        if version == 1:
            self._service.add_endpoint(Endpoint(**endpoint))
        else:
            self._service.set_endpoint(Endpoint(**endpoint))

        self._wait_for_endpoint_deployment(endpoint["name"], endpoint["version"])

    def _gen_endpoint(self, name, obj, description, version=1, schema=None):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update.
        obj : func
            Object that backs the endpoint. See deploy() for a complete
            description.
        description : str or None
            Description of the endpoint. If ``None``, the stripped docstring
            of *obj* is used (or ``""`` when it has none).
        version : int
            The version. Defaults to 1.
        schema : dict, optional
            Schema stored with the endpoint. Defaults to ``None``.

        Returns
        -------
        dict
            Keys:
            name : str
                The name provided.
            version : int
                The version provided.
            description : str
                The provided (or docstring-derived) description.
            type : str
                The type of the endpoint (always ``"model"``).
            endpoint_obj : object
                The wrapper around the obj provided that can be used to
                generate the code and dependencies for the endpoint.
            dependencies, methods
                Taken from the wrapper object.
            required_files, required_packages : list
                Always empty lists.
            schema
                The provided schema.

        Raises
        ------
        TypeError, ValueError
            If *name* is not a valid endpoint name.
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            if isinstance(obj.__doc__, str):
                # extract doc string
                description = obj.__doc__.strip()
            else:
                description = ""

        endpoint_object = CustomQueryObject(query=obj, description=description)

        return {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": endpoint_object,
            "dependencies": endpoint_object.get_dependencies(),
            "methods": endpoint_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "schema": schema,
        }

    def _upload_endpoint(self, obj):
        """Sends the endpoint across the wire.

        Saves ``obj["endpoint_obj"]`` under
        ``<upload destination>/endpoints/<name>/<version>`` and records that
        path in ``obj["src_path"]``. Note that *obj* is mutated in place.
        """
        endpoint_obj = obj["endpoint_obj"]
        dest_path = self._get_endpoint_upload_destination()

        # Upload the endpoint
        obj["src_path"] = os.path.join(
            dest_path, "endpoints", obj["name"], str(obj["version"])
        )

        endpoint_obj.save(obj["src_path"])

    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10,
    ):
        """
        Polls get_status() until *endpoint_name* is loaded at *version* or
        newer.

        Each poll behaves as follows:

        - If the endpoint is absent or still loading, keep waiting.
        - If its status is ``LoadSuccessful`` with an older version, keep
          waiting.
        - If its status is ``LoadFailed``, fail immediately.

        The loop sleeps *interval* seconds between polls. It gives up once
        more than *timeout* seconds have elapsed, measured with a monotonic
        clock.

        Raises
        ------
        RuntimeError
            On ``LoadFailed`` or when *timeout* is exceeded.
        """
        logger.info(
            f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}"
        )
        # Monotonic clock so wall-clock adjustments cannot shorten or extend
        # the wait.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(
                    f"Endpoint {endpoint_name} doesn't " "exist in endpoints yet"
                )
            else:
                logger.info(f"ep={ep}")
                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')
                elif ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        break
                    logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Waited more than {timeout}s for deployment")

            logger.info(f"Sleeping {interval}...")
            time.sleep(interval)

    def set_credentials(self, username, password):
        """
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        """
        self._service.set_credentials(username, password)