-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
157 lines (125 loc) · 4.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import cv2
import torch
import numpy as np
import time
import threading
import queue
from datetime import datetime
import random
from collections import deque
import socket
from controller.DroneNavigation import DroneNavigation
from tello_models.Q import QLearningAgent
from tello_models.Rewards import RewardCalculator
from predictions.PredictionProcessor import PredictionProcessor
from video.Video import VideoProcessor
cv2.namedWindow("Tello Video Stream", cv2.WINDOW_NORMAL)

# Select compute devices.
# Prefer Apple-silicon MPS when the build supports it, then CUDA, then CPU.
# (Previously this was hard-coded to "mps", which raises on any PyTorch
# build compiled without MPS support.)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# MiDaS runs on MPS when available, otherwise shares the main device.
midas_device = torch.device("mps") if torch.backends.mps.is_available() else device

# Initialize models (loaded from local torch.hub checkouts).
# NOTE(review): these absolute paths only exist on the author's machine —
# consider making them configurable.
yolo_model = torch.hub.load('/Users/josephsketl/yolov5', 'yolov5s', source='local', pretrained=True)
yolo_model.conf = 0.2  # Only keep detections with confidence > 0.2
midas_model = torch.hub.load("/Users/josephsketl/MiDaS/", "MiDaS_small", source="local")
midas_transforms = torch.hub.load("/Users/josephsketl/MiDaS/", "transforms", source="local").small_transform

# Tello IP and Ports (factory defaults for the Tello UDP protocol)
TELLO_IP = '192.168.10.1'
CMD_PORT = 8889
VIDEO_PORT = 11111
STATE_PORT = 8890
TELLO_ADDRESS = (TELLO_IP, CMD_PORT)

# Set up UDP socket for receiving state information
state_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
state_socket.bind(('', STATE_PORT))

frame_queue = queue.Queue(maxsize=20)  # Limit queue size for memory efficiency

# Object classes the agent should track
target_labels = ["vase", "apple", "sports ball", 'person']

# Initialize modules
prediction_processor = PredictionProcessor(
    yolo_model=yolo_model,
    midas_model=midas_model,
    midas_transforms=midas_transforms,
    target_labels=target_labels,
    device=device,
    midas_device=midas_device,
)
q_agent = QLearningAgent(actions=["forward", "left", "right", "cw"])
reward_calculator = RewardCalculator(target_color="blue")
drone = DroneNavigation(tello_ip=TELLO_IP, cmd_port=CMD_PORT)
def start_video_threads():
    """Spawn the daemon thread that captures frames from the drone.

    Relies on the module-level ``video_processor`` global; returns the
    already-started thread so the caller can join it later.
    """
    thread = threading.Thread(target=video_processor.capture_video, daemon=True)
    thread.start()
    return thread
def start_prediction_thread(prediction_processor, frame_queue):
    """Launch the prediction worker in the background.

    Starts ``prediction_processor.process_predictions(frame_queue)`` on a
    daemon thread (so it dies with the main program) and returns the thread.
    """
    worker = threading.Thread(
        target=prediction_processor.process_predictions,
        args=(frame_queue,),
        daemon=True,  # ensure the worker cannot outlive the interpreter
    )
    worker.start()
    return worker
# Main loop
def main_loop():
    """Run the Q-learning control loop until interrupted.

    Each iteration: read the latest detections from the global
    ``prediction_processor``, compute a reward via ``reward_calculator``,
    pick an action from ``q_agent`` (actual drone commands are currently
    commented out), update the Q-table, and advance the state counter
    (cyclic over 0..99).  On exit the Q-table is saved and the drone is
    told to land.
    """
    try:
        time.sleep(2)  # give the video/prediction threads time to warm up
        state = 0
        missing_frames_counter = 0  # consecutive iterations with no detections
        while True:
            # Get the latest prediction results
            prediction_results = prediction_processor.get_results()
            detected_objects = prediction_results["detected_objects"]
            if detected_objects:
                missing_frames_counter = 0  # Reset counter if objects are detected
            else:
                missing_frames_counter += 1
            # Calculate reward
            reward = reward_calculator.calculate_reward(
                detected_objects,
                missing_frames_counter,
                missing_frames_tolerance=1000
            )
            # Choose and execute action
            action_index = q_agent.choose_action(state)
            action_name = q_agent.actions[action_index]
            #drone.execute_action(action_name)
            # Update Q-table
            next_state = (state + 1) % 100  # states are a simple cyclic step counter
            q_agent.update_q_table(state, action_index, reward, next_state)
            time.sleep(3)  # pace the loop so each action has time to take effect
            # Transition to the next state
            state = next_state
    except KeyboardInterrupt:
        # NOTE(review): this runs on a non-main thread (see __main__), where
        # KeyboardInterrupt is normally only delivered to the main thread —
        # confirm this handler is actually reachable.
        print("Landing the drone...")
        drone.send_command("land")
    finally:
        # Persist learning progress and shut everything down safely,
        # landing the drone regardless of how the loop exited.
        q_agent.save_q_table()
        video_processor.stop()
        drone.send_command("land")
        drone.cmd_socket.close()
if __name__ == "__main__":
    # Put the Tello into SDK mode and enable its video stream before
    # anything tries to read frames.
    drone.send_command("command")
    drone.send_command("streamon")
    # Initialize VideoProcessor
    video_processor = VideoProcessor(
        tello_ip="192.168.10.1",
        video_port=11111,
        prediction_processor=prediction_processor,
    )
    # Start threads: one captures raw frames, one runs them through the
    # prediction pipeline.  Both are daemons so they die with the program.
    capture_thread = threading.Thread(target=video_processor.capture_video, daemon=True)
    process_thread = threading.Thread(target=video_processor.process_frames, daemon=True)
    capture_thread.start()
    process_thread.start()
    # The control loop also runs on its own daemon thread so the main
    # thread stays free for the display.
    main_loop_thread = threading.Thread(target=main_loop, daemon=True)
    main_loop_thread.start()
    # Run the PyGame-based display in the main thread (blocks until the
    # window is closed or Ctrl-C is pressed).
    try:
        video_processor.display_video_with_pygame()
    except KeyboardInterrupt:
        # Stop the video pipeline and wait for its threads to wind down.
        video_processor.stop()
        capture_thread.join()
        process_thread.join()