This repository has been archived by the owner on Nov 11, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run-ws.py
139 lines (103 loc) · 3.78 KB
/
run-ws.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import json
import socket
import signal
import asyncio
import logging
import argparse
import websockets
from jsonschema import validate, ValidationError, SchemaError
from chainlink import Chainlink
import grader.api_keys as api_keys
from grader.definitions import GRADING_JOB_DEF
from config import *
# Configure root logging at INFO level for this script (basicConfig is a
# no-op if the root logger already has handlers). The functions below all
# log through this module-level `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=None)  # name=None selects the root logger
async def exec_job(job):
    """Execute one grading job and build the result message for the server.

    Args:
        job: grading-job dict (``run`` validates it against GRADING_JOB_DEF
            before calling this); its ``api_keys.GRADING_JOB_ID`` and
            ``api_keys.STAGES`` entries are read.

    Returns:
        dict with the per-stage results (each with its ``"logs"`` entry
        removed), an overall success flag taken from the last stage, the
        newline-joined stdout/stderr of all stages, and the job id.
    """
    job_id = job[api_keys.GRADING_JOB_ID]
    stages = job[api_keys.STAGES]
    logger.info("starting job {}".format(job_id))

    # Any error while building or running the chain is reported back as a
    # single failed pseudo-stage rather than propagating out of the worker.
    try:
        chain = Chainlink(stages, workdir=os.getcwd())
        job_results = await chain.run_async({})
    except Exception as ex:
        logger.critical(
            "grading job failed with exception:\n{}".format(ex), exc_info=True
        )
        job_results = [
            {
                "logs": {
                    "stdout": b"the container crashed",
                    "stderr": str(ex).encode("utf-8", errors="replace"),
                },
                "success": False,
            }
        ]

    # Stage output is produced by the graded code and is not guaranteed to be
    # valid UTF-8. Substitute U+FFFD for bad bytes instead of raising
    # UnicodeDecodeError, which `run` does not catch and would end the worker.
    job_stdout = "\n".join(
        r["logs"]["stdout"].decode("utf-8", errors="replace") for r in job_results
    )
    job_stderr = "\n".join(
        r["logs"]["stderr"].decode("utf-8", errors="replace") for r in job_results
    )
    # Logs are returned once, aggregated under api_keys.LOGS; drop the
    # per-stage copies from the results.
    for r in job_results:
        del r["logs"]

    logger.info("finished job {}".format(job_id))
    if VERBOSE:
        logger.info("job stdout:\n" + job_stdout)
        logger.info("job stderr:\n" + job_stderr)

    # An empty result list (presumably a job with no stages) has no last stage
    # to report on; treat it as unsuccessful instead of raising IndexError.
    success = job_results[-1]["success"] if job_results else False

    return {
        api_keys.RESULTS: job_results,
        api_keys.SUCCESS: success,
        api_keys.LOGS: {"stdout": job_stdout, "stderr": job_stderr},
        api_keys.GRADING_JOB_ID: job_id,
    }
async def run(token, worker_id):
    """Connect to the grading API over a websocket and serve jobs until stopped.

    Registers this host under ``worker_id``, sending ``token`` as a bearer
    credential. It then repeatedly receives a job, validates it against
    GRADING_JOB_DEF, executes it with ``exec_job`` and sends the result back.
    A closed connection or a jsonschema validation/schema error is logged and
    ends the function normally. A rejected registration raises ``Exception``.
    """
    scheme = "wss" if USE_SSL else "ws"
    url = (
        f"{scheme}://{API_HOSTNAME}:{API_PORT}"
        f"{API_PROXY}{WORKER_WS_ENDPOINT}/{worker_id}"
    )
    auth_headers = {api_keys.AUTH: f"Bearer {token}"}
    this_host = socket.gethostname()

    async with websockets.connect(
        url, ping_interval=HEARTBEAT_INTERVAL, extra_headers=auth_headers
    ) as conn:
        try:
            registration = {"type": "register", "args": {"hostname": this_host}}
            await conn.send(json.dumps(registration))
            reply = json.loads(await conn.recv())
            if not reply["success"]:
                raise Exception("failed to register")
            logger.info(f"registered as {worker_id}")

            # Serve jobs one at a time for as long as the connection stays up.
            while True:
                incoming = json.loads(await conn.recv())
                validate(instance=incoming, schema=GRADING_JOB_DEF)
                outcome = await exec_job(incoming)
                await conn.send(json.dumps({"type": "job_result", "args": outcome}))
        except websockets.ConnectionClosed as err:
            logger.critical(f"connection closed: {err!r}")
        except ValidationError as err:
            logger.critical(f"validation error: {err!r}")
        except SchemaError as err:
            logger.critical(f"schema error: {err!r}")
def parse_args(argv=None):
    """Parse the worker's command-line arguments.

    Args:
        argv: argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            no-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with ``token`` and ``worker_id`` attributes.

    Raises:
        SystemExit: on missing/invalid arguments or ``--help`` (argparse).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("token", help="Broadway cluster token")
    parser.add_argument(
        "worker_id", metavar="worker-id", help="Unique worker id for registration"
    )
    return parser.parse_args(argv)
def shutdown(sig, task):
    """Handle a termination signal by logging it and cancelling ``task``.

    Args:
        sig: number of the signal that was received; it is logged by its
            ``signal.Signals`` name.
        task: the asyncio task to cancel (the main block passes the ``run`` task).
    """
    name = signal.Signals(sig).name
    logger.info(f"signal received: {name}, shutting down")
    task.cancel()
if __name__ == "__main__":
    args = parse_args()

    # asyncio.get_event_loop() is deprecated when no loop is running and
    # raises RuntimeError on Python 3.14+, so create and install one explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    task = loop.create_task(run(args.token, args.worker_id))

    # Route Ctrl-C and SIGTERM (sent by process managers / `docker stop`) to
    # `shutdown`, so `run()` is cancelled and unwinds through its `async with`
    # block rather than the process being killed outright. Passing the args
    # to add_signal_handler avoids a late-binding lambda.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, shutdown, sig, task)

    try:
        loop.run_until_complete(task)
    except asyncio.CancelledError:
        logger.info("task cancelled")
    except Exception as e:
        logger.critical("unexpected error: {}".format(repr(e)))
    finally:
        # Release the loop's resources on every exit path.
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()