-
Notifications
You must be signed in to change notification settings - Fork 0
/
MainServer.py
389 lines (320 loc) · 16.6 KB
/
MainServer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
from twisted.internet.protocol import ServerFactory
from twisted.protocols.basic import LineReceiver, NetstringReceiver
from twisted.internet.endpoints import TCP4ServerEndpoint
from twisted.internet import reactor
from sys import stdout
from Team import Team
from vars import *
import json, datetime, traceback
from db import DBHandler
import team_utils as tu
import log_utils as lu
import message_templates as mt
import attributes as atts
import server_utils as su
import keys as k
# TODO in order:
# Incorporate drone COMMUNICATION
# Make sure we can set up a connection with another computer.
# Consider sending FAILED requests again after all WAITING ones have been sent.
# Check that when a team submits a bid, its drone is not already being considered for another bid.
class TeamServerSideProtocol(NetstringReceiver):
    '''
    Protocol instance serving one team connection.

    Every received netstring is decoded, parsed as JSON and dispatched on its
    'type' field (see stringReceived). The factory keeps the shared state:
    factory.protocols maps each protocol to the team id it serves, which is
    None right after the connection is made.
    '''

    # -----------------TWISTED PROTOCOL METHODS--------------------------------#
    def connectionMade(self):
        '''
        Called when a connection is made with a team for the first time.
        Registers the connection with the factory and logs the new count.
        Note: self.factory was set by the factory's default buildProtocol
        '''
        self.db = self.factory.db
        # Remember we work with protocols. Every protocol is a connection.
        # No team is attached yet; presumably tu.processAuthentication fills
        # this in later (confirm in team_utils).
        self.factory.protocols[self] = None
        self.factory.numProtocols += 1
        lu.writeToLog(self.factory,
                      'Connection made. There are currently %d open connections.' % (self.factory.numProtocols,))

    def connectionLost(self, reason):
        '''
        Called when connection to team is lost. Update relevant variables and
        clean up data structures so team can log in again.

        reason: failure object supplied by Twisted; not used here.
        '''
        team_id = self.factory.protocols.get(self)
        try:
            # A connection that never authenticated has no team to log out.
            if team_id is not None:
                tu.logOutTeam(self.factory, team_id)
        finally:
            # Always release the connection slot, even if the logout failed,
            # so the open-connection count cannot drift.
            self.factory.numProtocols -= 1
            self.factory.protocols.pop(self, None)
            lu.writeToLog(self.factory,
                          'Connection lost. There are currently %d open connections.' % (self.factory.numProtocols,))

    def stringReceived(self, line):
        '''
        Called when line/message is received from team. Always check 'type'
        to decide what to do.

        line: raw bytes payload of one netstring; expected to hold a JSON object.

        Any exception raised while handling the message is logged and
        printed here, so a bad message never tears down the connection.
        '''
        try:
            # Get message and log it.
            team_id = self.factory.protocols[self]
            message = line.decode()
            lu.writeToLogFromTeam(self.factory, message, team_id, verbose=True)
            # deserialize message
            message = json.loads(message)
            # Check if message has 'type' attribute
            if not su.hasattr(self, message, 'type'):
                return
            msg_type = message['type']

            # Team Authentication
            if msg_type == 'auth':
                tu.processAuthentication(self, message)
                return

            # Team is trying to send a message without our approval or before logging in.
            if team_id is None or not self.factory.teams[team_id].isLoggedIn():
                tu.writeToTeam(self, mt.PLEASE_LOGIN_MSG)
                return

            # Team is logged in. We can accept data.
            if msg_type == 'drone_state':
                # Team is sending drone state information, accept it and store it.
                self.updateDroneState(team_id, message)
                # NOTE(review): updateDroneState already answers the team on
                # success and on error, so this is a second reply - confirm
                # whether clients rely on it before removing.
                tu.writeToTeam(self, mt.THANKS_MSG, verbose=False)
            elif msg_type == 'logout':
                # Team want to log out. Log them out.
                tu.logOutTeam(self.factory, team_id)
            elif msg_type == 'bid':
                self.onReceiveBid(team_id, message)
            elif msg_type == 'task_update':
                # The handler is a method of this protocol; the factory has
                # no onReceiveTaskUpdate.
                self.onReceiveTaskUpdate(message)
            # Any other type is ignored without a reply.
        except Exception as e:
            lu.writeToLog(self.factory, '[ERROR] {}'.format(str(e)), verbose=False)
            print(traceback.format_exc())

    def updateDroneState(self, team_id, message):
        '''
        Stores the drone state carried by a 'drone_state' message.

        team_id: id of the team that sent the message.
        message: parsed message; must contain 'drone_state' holding the
                 fields listed in atts.DRONE_STATE_ATTRS.

        Replies THANKS_MSG when the team object reports success and an error
        response when storing raises.
        '''
        # Make sure message has all needed fields.
        if not su.hasattr(self, message, 'drone_state'):
            return
        new_drone_state = message['drone_state']
        if not su.hasattrs(self, new_drone_state, atts.DRONE_STATE_ATTRS):
            return
        # Try to insert state into DB. Write back result to team.
        try:
            # NOTE(review): 'upateDroneState' looks like a misspelling of
            # 'updateDroneState'; kept as-is because Team is defined elsewhere.
            success = self.factory.teams[team_id].upateDroneState(
                new_drone_state['drone_id'], new_drone_state, datetime.datetime.now())
            if success:
                tu.writeToTeam(self, mt.THANKS_MSG, verbose=False)
        except Exception as e:
            tu.writeToTeam(self, mt.ERROR_RESPONSE(
                'DB error. Make sure all fields are correct. Error: {}'.format(str(e))))  # Send error
            print(traceback.format_exc())

    def onReceiveBid(self, team_id, message):
        '''
        Handles bid reception from team. Checks request state, stores the bid
        in the database and updates request state. If all teams that were
        sent the request have submitted a bid, stops the timeout and sends
        the bid results right away.

        team_id: id of the bidding team.
        message: parsed message; must contain 'bid' holding the fields listed
                 in atts.BID_ATTRS.

        Every failure is logged and reported back to the team as an error
        response; nothing is raised to the caller.
        '''
        # Make sure message is formatted properly.
        if not su.hasattr(self, message, 'bid'):
            return
        bid = message['bid']
        if not su.hasattrs(self, bid, atts.BID_ATTRS):
            return

        # Get request that corresponds to bid from DB and handle errors.
        request_id = bid['request_id']
        try:
            with self.db:
                request = self.db.query_one('SELECT * FROM requests WHERE request_id = %s', (request_id,))
            if request is None:
                raise Exception('[ERROR ON RECEIVING BID] Key error [request_id = {}]'.format(request_id))

            # Case when bid already timed out
            if request['state'] != 'SENT' and request['state'] != 'ACCEPTED':
                # Arguments are (request_id, team_id) to match {0}=REQ, {1}=TEAM.
                raise Exception("[ERROR ON RECEIVING BID] Already timeout [TEAM# {1}, REQ# {0}]".format(request_id, team_id))

            # Try to put into DB and update request_states
            with self.db:
                bid_exists = self.db.count('Bids', 'request_id = %s AND team_id = %s', (request_id, team_id))
                if bid_exists:
                    raise Exception('[ERROR ON RECEIVING BID] Bid already exists [TEAM# {1}, REQ# {0}]'.format(request_id, team_id))
                self.db.insert_values(
                    'Bids',
                    (DBHandler.DEFAULT, bid['price'], bid['seconds_expected'], bid['drone_id'],
                     bool(bid['accepted']), False, team_id, request_id))
                # count number of bids
                n_bids = self.db.count('Bids', 'request_id = %s', (request_id,))
                sent_row = self.db.query_one('SELECT sent_to FROM Requests WHERE request_id = %s', (request_id,))
                if sent_row is None:
                    raise Exception('[ERROR ON RECEIVING BID] [TEAM# {1}, REQ# {0}]'.format(request_id, team_id))
                n_sent = sent_row[0]
            tu.writeToTeam(self, mt.THANKS_MSG, verbose=False)

            # Check if all teams submitted bids to stop timeout and send results.
            if n_bids == n_sent:
                with self.db:
                    self.db.query_one('UPDATE Requests SET state = %s WHERE request_id = %s; ',
                                      ('ALL_ACCEPTED', request_id))
                lu.writeToLog(self.factory, "Received ALL bids for REQ# {}, timeout callback canceled.".format(request_id))
                # Cancelling a delayed call that already fired or was already
                # cancelled raises, so only cancel one that is still pending.
                timeout_call = self.factory.requestCallIds.pop(request_id, None)
                if timeout_call is not None and timeout_call.active():
                    timeout_call.cancel()
                self.factory.sendBidResults(request_id)
            else:
                # NOTE(review): the state becomes ACCEPTED even when this bid
                # has accepted == False - confirm that is intended.
                with self.db:
                    self.db.query_one('UPDATE Requests SET state = %s WHERE request_id = %s;',
                                      ('ACCEPTED', request_id))
        except Exception as e:
            lu.writeToLog(self.factory, str(e))
            tu.writeToTeam(self, mt.ERROR_RESPONSE('Bidding error. Error: {}'.format(str(e))))
            print(traceback.format_exc())

    def onReceiveTaskUpdate(self, message):
        '''
        Handles a 'task_update' message by recording the reported status.

        message: parsed message; must contain the fields listed in
                 atts.TASK_UPDATE_ATTS, including 'status'.

        Replies with an error response for an unknown status or a missing
        request id, and THANKS_MSG otherwise.
        '''
        if not su.hasattrs(self, message, atts.TASK_UPDATE_ATTS):
            return
        status = message['status']
        if status not in atts.POSSIBLE_TASK_STATUSES:
            tu.writeToTeam(self, mt.ERROR_RESPONSE('Invalid status.'))
            return

        # Without a request id the UPDATE below would touch every row.
        # NOTE(review): assumes the update carries 'request_id' - confirm
        # against atts.TASK_UPDATE_ATTS.
        request_id = message.get('request_id')
        if request_id is None:
            tu.writeToTeam(self, mt.ERROR_RESPONSE('Missing request_id.'))
            return

        tu.writeToTeam(self, mt.THANKS_MSG)
        # A denied task is stored as FAILED; every other status is stored upper-cased.
        # NOTE(review): FAILED is not defined in this file, presumably it comes
        # from 'from vars import *'. Also note the rest of this file tracks
        # requests through a 'state' column, not 'status' - confirm the column.
        new_status = FAILED if status == 'deny' else status.upper()
        with self.factory.db:
            self.factory.db.query_list('UPDATE Requests SET status = %s WHERE request_id = %s;',
                                       (new_status, request_id))
        # TODO: for 'pickup', make sure that it is true that the team's drone is close to the pickup port.
        # TODO: for 'success', make sure that it is true that the team's drone is close to the
        #       expected port (the original note said "pickup port" - confirm which one).
class MainFactory(ServerFactory):
    '''
    Factory that owns the state shared by all team connections and drives
    request broadcasting, bid timeouts and bid result delivery.
    '''

    def __init__(self):
        '''
        Initializes variables:
        db: DBHandler built from the credentials in keys
        numProtocols: number of open protocols
        protocols: maps { protocol: team_id }
        teams: maps { team_id: Team object }
        requestCallIds = { request_id : requestCallId }
        '''
        self.db = DBHandler(k.DB_NAME, k.DB_USER, k.DB_PASSWORD)
        self.numProtocols = 0
        self.protocols = {}
        self.requestCallIds = {}
        self._loadTeams()

    def _loadTeams(self):
        '''
        Load team info from DB and initialize team objects.
        '''
        with self.db:
            teams = self.db.query_list('SELECT * FROM Teams;')
            self.db.query_list('UPDATE Teams SET is_logged_in = FALSE;')  # reset all login states to logged out.
        self.teams = {team['team_id']: Team(team['team_id'], team['password'], self.db) for team in teams}

    def _loggedInProtocols(self):
        '''
        Returns { protocol: team_id } for the connections whose team is
        authenticated and logged in - the same test the protocol applies
        before it accepts data from a team.
        '''
        return {protocol: team_id
                for protocol, team_id in self.protocols.items()
                if team_id is not None and self.teams[team_id].isLoggedIn()}

    # ---------------TWISTED FACTORY METHODS----------------------------------#
    # This will be used by the default buildProtocol to create new protocols:
    protocol = TeamServerSideProtocol

    def startFactory(self):
        '''
        Called when factory starts running. Sets isRunning to true for logging.
        Opens the log file and schedules the first request broadcast.
        '''
        print('[SERVER] Factory started. Waiting for connections.')
        self.isRunning = True
        self.log = open(LOG_NAME, 'a')
        self.log.write('\nStarting new log: ' + str(datetime.datetime.now()) + '\n')
        self.startNextBroadcastTimer()

    def stopFactory(self):
        '''
        Called when factory is shutting down. Set isRunning to false so we can
        log without errors, and closes opened files.
        '''
        self.isRunning = False
        self.log.write('Stopping factory. \n')
        self.log.close()

    # ----------------REQUEST BROADCASTING------------------------------------#
    def startNextBroadcastTimer(self):
        '''
        Fetches the next request and schedules its broadcast for the time
        stored in the request's 'time_requested' field. Does nothing when no
        request is left. Errors are logged and printed, not raised.
        '''
        try:
            # Retrieve request.
            request = su.getNextRequest(self, self.db)
            if request is None:
                lu.writeToLog(self, 'No WAITING requests left.')
                return
            # Start timer (by use of callback). timedelta.seconds is only the
            # seconds component of a normalized delta: it ignores days and is
            # close to 86400 for a slightly negative delta. Use the full
            # duration and broadcast overdue requests immediately.
            # Assumes time_requested is a naive datetime like now() - TODO confirm.
            remaining = request['time_requested'] - datetime.datetime.now()
            time_til_broadcast = max(0.0, remaining.total_seconds())
            lu.writeToLog(self, 'Time until next broadcast: {}'.format(time_til_broadcast))
            reactor.callLater(time_til_broadcast, self.broadcastNextRequest, request)
        except Exception as e:
            lu.writeToLog(self, '[ERROR] Failed at broadcastTimer. Error: {}'.format(str(e)))
            print(traceback.format_exc())

    def broadcastNextRequest(self, request):
        '''
        Broadcasts the given request to every logged-in team, starts the bid
        timeout for it, marks it as SENT in the DB together with the number
        of recipients, and schedules the next broadcast.

        request: request row as returned by su.getNextRequest.

        Errors are logged and printed, not raised.
        NOTE(review): after an error the next broadcast is not scheduled, so
        broadcasting stops until the factory is restarted.
        '''
        try:
            # Create request message for broadcast.
            request_msg = mt.REQUEST_MSG(request)
            request_id = request['request_id']
            # Only logged-in teams can bid, so only they receive the request
            # and only they count towards sent_to; otherwise the "all bids
            # received" check could never match while an unauthenticated
            # connection is open.
            recipients = self._loggedInProtocols()
            # Team ids are not necessarily strings, so convert before joining.
            lu.writeToLog(self, 'Brodcasting REQ# {} to Teams: {}'.format(
                request_id, ','.join(str(team_id) for team_id in recipients.values())))
            for protocol in recipients:
                tu.writeToTeam(protocol, request_msg, verbose=False)
            self.requestCallIds[request_id] = reactor.callLater(REQUEST_TIMEOUT, self.bidTimeOut, request_id)
            # Update request state to sent.
            with self.db:
                self.db.query_one('UPDATE Requests SET state = %s, sent_to = %s WHERE request_id = %s;',
                                  ('SENT', len(recipients), request_id))
            self.startNextBroadcastTimer()
        except Exception as e:
            lu.writeToLog(self, '[ERROR] Failed broadcast requests. Error: {}'.format(str(e)))
            print(traceback.format_exc())

    def bidTimeOut(self, request_id):
        '''
        Call when bid request times out. If the request is still in state
        SENT it is marked NOT_ACCEPTED; otherwise the bid results are sent.
        Errors are logged and printed, not raised.
        '''
        try:
            lu.writeToLog(self, "[BID TIMEOUT] [REQ# {}]".format(request_id))
            with self.db:
                # Check if someone accepted. If not, update state.
                state_row = self.db.query_one('SELECT state FROM Requests WHERE request_id = %s', (request_id,))
                if state_row is None:
                    raise Exception('Failed to query request state.')
                if state_row[0] == 'SENT':
                    self.db.query_one('UPDATE Requests SET state = %s WHERE request_id = %s;',
                                      ('NOT_ACCEPTED', request_id))
                    lu.writeToLog(self, '[Request NOT_ACCEPTED] [REQ# {}]'.format(request_id))
                    return
            # If bid was accepted, collect bids.
            self.sendBidResults(request_id)
        except Exception as e:
            lu.writeToLog(self, '[ERROR] On bid timeout. Error: {}'.format(str(e)))
            print(traceback.format_exc())

    # ----------------BID HANDLING--------------------------------------------#
    def selectBestBid(self, bids):
        '''
        Returns the winning bid. Placeholder: returns the first bid.
        TODO implement once we know how to calculate the best bid.
        '''
        return bids[0]

    def sendBidResults(self, request_id):
        '''
        Collects the accepted bids for REQ# request_id, picks the best one,
        marks the request as ASSIGNED and notifies the winning team and the
        losing teams.

        Errors (including "no accepted bid found") are logged and printed,
        not raised.
        '''
        try:
            # Get accepted bids and request from DB.
            with self.db:
                bids_accepted = self.db.query_list(
                    'SELECT * FROM Bids WHERE request_id = %s AND accepted = TRUE;', (request_id,))
                # Raising a plain string is itself a TypeError, which hid the
                # real message; raise a proper exception instead.
                if not bids_accepted:
                    raise Exception('Failure collecting bids.')
                request = self.db.query_one('SELECT * FROM Requests WHERE request_id = %s', (request_id,))
                if request is None:
                    raise Exception('Failure collecting bids.')

            # Send result to winning bid
            best_bid = self.selectBestBid(bids_accepted)
            bid_winning_team = best_bid['team_id']
            now = datetime.datetime.now()
            time_expected = now + datetime.timedelta(seconds=best_bid['seconds_expected'])
            with self.db:
                self.db.query_one(
                    'UPDATE Requests SET state = %s, time_assigned = %s, time_expected = %s WHERE request_id = %s;',
                    ('ASSIGNED', now, time_expected, request_id))
            # The winner may have disconnected since bidding; do not let that
            # prevent the losing teams from being notified.
            winner_protocol = self.teams[bid_winning_team].protocol
            if winner_protocol is not None:
                tu.writeToTeam(winner_protocol, mt.WINNING_BID_RESULT(request, best_bid, str(time_expected)))
            else:
                lu.writeToLog(self, '[ERROR] On sendBidResults. Winning team {} is not connected [REQ# {}]'.format(
                    bid_winning_team, request_id))

            # Send result to losing teams
            for bid in bids_accepted:
                if bid['team_id'] == bid_winning_team:
                    continue
                loser_protocol = self.teams[bid['team_id']].protocol
                if loser_protocol is not None:
                    tu.writeToTeam(loser_protocol, mt.LOSING_BID_RESULT(request))
        except Exception as e:
            lu.writeToLog(self, '[ERROR] On sendBidResults. Error: {}'.format(str(e)))
            print(traceback.format_exc())
# 8007 is the port you want to run under. Choose something >1024
def main(port=8007):
    '''
    Starts the server: listens for team connections on the given TCP port
    and runs the reactor until it is stopped.

    port: TCP port to listen on; defaults to 8007.
    '''
    def on_listen_failed(failure):
        # Without this the reactor would keep running with nothing listening
        # (e.g. when the port is already in use). The failure can arrive
        # before the reactor starts, so defer the stop until it is running.
        print('[SERVER] Could not listen on port {}: {}'.format(port, failure.getErrorMessage()))
        reactor.callWhenRunning(reactor.stop)

    endpoint = TCP4ServerEndpoint(reactor, port)
    endpoint.listen(MainFactory()).addErrback(on_listen_failed)
    reactor.run()
# Start the server only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()