-
Notifications
You must be signed in to change notification settings - Fork 2
/
grapher.py
360 lines (299 loc) · 15 KB
/
grapher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
from cProfile import label
from cmath import inf
import sys
from tkinter import Y
from turtle import color
from xml.etree.ElementTree import tostring
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
from matplotlib.axis import Axis
import numpy as np
import pandas as pd
import csv
import string
import traceback
import PyQt5.QtWidgets as qtw
import PyQt5.QtGui as qtg
import PyQt5.QtCore as qtc
import gait_parameters
class MplCanvas(FigureCanvasQTAgg):
    """Qt canvas widget showing a figure with two vertically stacked axes.

    Attributes:
        fig: the matplotlib Figure rendered by this canvas.
        axes: array of the two Axes; index 0 is the top plot, index 1 the bottom.
    """

    def __init__(self):
        #build the Figure directly instead of through pyplot: pyplot would also
        #register the figure globally and attach its own canvas/manager to it,
        #which is unnecessary when the figure is embedded in a Qt widget
        self.fig = Figure(constrained_layout=True)
        self.axes = self.fig.subplots(2, 1)
        super().__init__(self.fig)
class DataDisplay(qtw.QWidget):
    """Widget that graphs tracked body-part coordinates against frame number.

    The top axes shows the x coordinate and the bottom axes the y coordinate of
    every body part selected in the list widget. Points whose likelihood is
    below ``self.threshold`` are hidden, and a red vertical line marks
    ``self.current_frame`` on both axes.
    """

    #(title, y-axis label) for the top and bottom axes, in that order
    _AXES_TEXT = (
        ('X Coordinate by Frame', 'Pixel Coordinate (X)'),
        ('Y Coordinate by Frame', 'Pixel Coordinate (Y)'),
    )
    #appearance of the vertical line marking the current frame
    _MARKER_STYLE = {'color': 'r', 'label': 'current frame'}

    def __init__(self):
        super().__init__()
        self.mouse_hold = False #bool for mouse being currently held down
        self.num_frames = 0
        self.current_frame = 0 #stores current frame for vertical line plotting
        self.video_duration = 1 #prevents a divide by zero error when video_position_changed initially executes
        self.bodypart_list = [] #stores the names of all the unique body parts in string
        self.threshold = 0 #likelihood threshold for filtering points, by default set to 0 (all points meet threshold)
        #create an empty data frame using pandas API
        self.data_frame = pd.DataFrame()
        self.init_ui()
        self.show()

    def init_ui(self):
        """Create the buttons, the plot canvas and the list widget and lay them out."""
        #create open-file button
        self.openBtn = qtw.QPushButton('Open CSV Data')
        self.openBtn.clicked.connect(self.open_file)
        #create toggle points by threshold button
        self.thresholdBtn = qtw.QPushButton('Set Likelihood Threshold')
        self.thresholdBtn.clicked.connect(self.set_likelihood_threshold)
        #create calculate a new column button
        self.calcBtn = qtw.QPushButton("Calculate Gait Parameters")
        self.calcBtn.clicked.connect(self.calc_gait_parameters)
        #create save data button
        self.saveBtn = qtw.QPushButton('Save Data')
        self.saveBtn.clicked.connect(self.save_filtered_data)

        #create the matplotlib FigureCanvas object
        self.plot = MplCanvas()
        for ax, (title, _) in zip(self.plot.axes, self._AXES_TEXT):
            ax.set_title(title)
        self.plot.mpl_connect('button_press_event', self.click_graph)
        self.plot.mpl_connect('scroll_event', self.zoom)
        self.plot.mpl_connect('button_release_event', self.release_graph)
        self.plot.mpl_connect('motion_notify_event', self.move_mouse)

        #create a list_widget to control plotted variables
        self.list_widget = qtw.QListWidget()
        self.list_widget.setMaximumWidth(200)
        self.list_widget.setSelectionMode(qtw.QAbstractItemView.MultiSelection)
        self.list_widget.itemClicked.connect(self.change_plotted_data)
        self.list_widget.itemSelectionChanged.connect(self.change_plotted_data)

        #add widgets to layout
        buttonLayout = qtw.QHBoxLayout()
        for button in (self.openBtn, self.thresholdBtn, self.calcBtn, self.saveBtn):
            buttonLayout.addWidget(button)
        plotLayout = qtw.QVBoxLayout()
        plotLayout.addWidget(self.plot)
        plotLayout.addLayout(buttonLayout)
        graphLayout = qtw.QHBoxLayout()
        graphLayout.addLayout(plotLayout)
        graphLayout.addWidget(self.list_widget)
        self.setLayout(graphLayout)

    #Open CSV data, Plot on Graph
    def open_file(self):
        """Ask for a CSV/Excel file, load it and reset the list widget and axes.

        Any error raised while loading is shown in a warning box and printed;
        in that case the previously loaded data is left untouched.
        """
        filename, _ = qtw.QFileDialog.getOpenFileName(self, "Open Spreadsheet Data")
        if not filename:
            return #dialog was cancelled
        try:
            #load into a local first so a failure cannot leave self.data_frame half-cleaned
            frame = self._load_data_frame(filename)
            self.data_frame = frame
            self.num_frames = len(frame.index)
            self._populate_bodypart_list()
            self._reset_axes()
        except Exception as e:
            show_warning_messagebox(str(e))
            traceback.print_exc()

    def _load_data_frame(self, filename):
        """Read *filename* and return a numeric frame with '<bodypart>_<coord>' columns.

        Raises:
            ValueError: if the file name ends in neither ".csv" nor ".xlsx".
        """
        if filename.endswith(".csv"):
            frame = pd.read_csv(filename)
        elif filename.endswith(".xlsx"):
            frame = pd.read_excel(filename)
        else:
            raise ValueError("Unsupported file type. Only CSV and Excel are allowed.")
        #clean data by combining labels and reindexing: rows 0 and 1 hold the
        #body-part and coordinate labels, the remaining rows hold the data
        bodyparts_labels = frame.loc[0]
        coords_labels = frame.loc[1]
        frame.columns = [i + "_" + j for i, j in zip(bodyparts_labels, coords_labels)]
        frame = frame.iloc[2:, :].reset_index(drop=True)
        frame = frame.drop(columns=["bodyparts_coords"])
        #convert dtype from object to float64; unparsable cells become NaN
        return frame.apply(pd.to_numeric, errors='coerce')

    def _populate_bodypart_list(self):
        """Rebuild bodypart_list and the list widget from the data frame columns."""
        #silence itemSelectionChanged while the widget is rebuilt so the plot
        #is not redrawn against a partially filled list
        self.list_widget.blockSignals(True)
        try:
            self.bodypart_list.clear()
            self.list_widget.clear()
            for col in self.data_frame.columns:
                #only add one entry for each triplet of x, y, likelihood; strip
                #just the final '_<coord>' suffix so names containing '_' survive
                bodyparts_label = str(col).rsplit('_', 1)[0]
                if bodyparts_label not in self.bodypart_list:
                    self.bodypart_list.append(bodyparts_label)
                    self.list_widget.addItem(qtw.QListWidgetItem(bodyparts_label))
        finally:
            self.list_widget.blockSignals(False)

    def _reset_axes(self):
        """Clear both axes and restore titles, labels, marker line and legend."""
        for ax, (title, ylabel) in zip(self.plot.axes, self._AXES_TEXT):
            ax.clear() #also discards the title, so it is set again below
            ax.set_title(title)
            ax.set_xlabel('Frame Number')
            ax.set_ylabel(ylabel)
            ax.margins(x=0, y=0)
            #add the marker before the legend so the legend has an entry to show
            ax.axvline(x=0, **self._MARKER_STYLE)
            ax.legend()
        self.plot.draw_idle()

    def _thresholded_coordinates(self, body_part):
        """Return (x, y, likelihood) float arrays for *body_part*.

        x and y are NaN wherever the likelihood is below self.threshold.
        """
        x_data = self.data_frame[body_part + "_x"].to_numpy(dtype=float)
        y_data = self.data_frame[body_part + "_y"].to_numpy(dtype=float)
        likelihood_data = self.data_frame[body_part + "_likelihood"].to_numpy(dtype=float)
        below_threshold = likelihood_data < self.threshold
        x_data = np.where(below_threshold, np.nan, x_data)
        y_data = np.where(below_threshold, np.nan, y_data)
        return x_data, y_data, likelihood_data

    #save the data w/ current threshold to file
    def save_filtered_data(self):
        """Write x, y and likelihood of every body part to a CSV chosen by the user.

        Coordinates whose likelihood is below the current threshold are written as NaN.
        """
        if not self.bodypart_list:
            show_warning_messagebox("No data to save! Try loading a data file.")
            return
        save_path, _ = qtw.QFileDialog.getSaveFileName(self, "Save Filtered Data Points to File", '', '*.csv')
        if not save_path:
            return #dialog was cancelled
        columns = {}
        for body_part in self.bodypart_list:
            x_data, y_data, likelihood_data = self._thresholded_coordinates(body_part)
            columns[body_part + "_x"] = x_data
            columns[body_part + "_y"] = y_data
            columns[body_part + "_likelihood"] = likelihood_data
        try:
            pd.DataFrame(columns).to_csv(save_path)
        except OSError as e:
            show_warning_messagebox(str(e))

    #switch the data plotted on the graph
    def change_plotted_data(self):
        """Redraw both axes for the body parts currently selected in the list widget."""
        x_axes, y_axes = self.plot.axes
        for ax in (x_axes, y_axes):
            #iterate over a copy because removing an artist mutates ax.lines
            for line in list(ax.lines):
                line.remove()

        x_series, y_series = [], []
        for item in self.list_widget.selectedItems():
            name = item.text()
            x_data, y_data, _ = self._thresholded_coordinates(name)
            x_axes.plot(range(len(x_data)), x_data, label=name, marker='.')
            y_axes.plot(range(len(y_data)), y_data, label=name, marker='.')
            x_series.append(x_data)
            y_series.append(y_data)

        #the marker must be the last line added: _set_frame_marker relies on that
        for ax in (x_axes, y_axes):
            ax.axvline(x=self.current_frame, **self._MARKER_STYLE)
            ax.legend()

        #reset vertical axis range
        x_axes.set_ylim(*self._padded_limits(x_series))
        y_axes.set_ylim(*self._padded_limits(y_series))
        self.plot.draw_idle()

    @staticmethod
    def _padded_limits(series):
        """Return (low, high) spanning the finite values in *series* plus 10% padding.

        Falls back to the range 0..1 (before padding) when there is nothing
        finite to show, e.g. nothing is selected or every point is filtered out.
        """
        finite = [values[np.isfinite(values)] for values in series]
        finite = [values for values in finite if values.size]
        if finite:
            low = min(values.min() for values in finite)
            high = max(values.max() for values in finite)
        else:
            low, high = 0, 1
        padding = (high - low) * 0.1
        return low - padding, high + padding

    #=======GRAPH INTERACTIVITY========
    def _set_frame_marker(self, frame):
        """Store *frame* as the current frame and move the vertical line to it."""
        self.current_frame = frame
        for ax in self.plot.axes:
            #the marker is always the most recently added line on an axes
            if ax.lines:
                ax.lines[-1].remove()
            ax.axvline(x=frame, **self._MARKER_STYLE)
        self.plot.draw_idle()

    def _seek_to_event(self, event):
        """Move the frame marker to the x position of a mouse event inside the axes."""
        if event.xdata is None or all(event.inaxes != ax for ax in self.plot.axes):
            return
        self._set_frame_marker(int(event.xdata))

    def click_graph(self, event):
        """Start a drag and move the frame marker to the clicked position."""
        self.mouse_hold = True
        self._seek_to_event(event)

    def release_graph(self, event):
        """End the drag started by click_graph."""
        self.mouse_hold = False

    def move_mouse(self, event):
        """While the mouse button is held, drag the frame marker with the cursor."""
        if self.mouse_hold:
            self._seek_to_event(event)

    def zoom(self, event):
        """Zoom the x range of both axes around the cursor on a scroll event."""
        xdata = event.xdata # get event x location
        if xdata is None or self.num_frames == 0:
            return #scrolled outside the axes, or no data loaded yet
        cur_xlim = self.plot.axes[0].get_xlim()
        cur_xrange = (cur_xlim[1] - cur_xlim[0])*.5
        #'up' zooms in, 'down' zooms out; change 1.5 to change magnitude of zoom
        scale_factor = {'up': 1/1.5, 'down': 1.5}.get(event.button, 1)
        # set new limits, clamped to the available frames
        xmin = max(xdata - cur_xrange*scale_factor, 0)
        xmax = min(xdata + cur_xrange*scale_factor, self.num_frames)
        if xmin >= xmax:
            return #clamping left an empty range; keep the current view
        for ax in self.plot.axes:
            ax.set_xlim([xmin, xmax])
        self.plot.draw_idle()

    #========DATA MANIPULATION===========
    def set_likelihood_threshold(self):
        """Ask the user for a new likelihood threshold and replot if one is confirmed."""
        threshold, accepted = qtw.QInputDialog.getDouble(self,
            "Threshold Dialog",
            "Enter a likelihood value between 0-1. Graph will only display points above this threshold.",
            value=self.threshold, min=0, max=1, decimals=3)
        if not accepted:
            return #cancelled: keep the existing threshold
        self.threshold = threshold
        self.change_plotted_data()

    def calc_gait_parameters(self):
        """Open the gait parameter dialog for all listed landmarks and print its results."""
        items = [self.list_widget.item(i).text() for i in range(self.list_widget.count())]
        if not items:
            qtw.QMessageBox.warning(self, "Warning", "No landmarks are available! Try loading a data file.")
            return
        dialog = gait_parameters.ParameterInputDialog(items, self.data_frame)
        if dialog.exec_() == qtw.QDialog.Accepted:
            print("Landmarks:", dialog.confirmed_landmarks)
            print("Gait Parameters:", dialog.queried_gait_parameters)
            print("Summary Statistics:", dialog.summ_stats)

    #========VIDEO FUNCTIONALITY=========
    #Slide a vertical line along the graph as the video frame changes
    def video_position_changed(self, position):
        """Move the frame marker to *position* rounded to the nearest integer.

        NOTE(review): no unit conversion is performed here, so *position* is
        presumably already a frame number - confirm against the caller.
        """
        self._set_frame_marker(round(position))

    #Store the duration of video in graph object, supports vertical line scrubbing function.
    def video_duration_changed(self, duration):
        """Remember the video duration reported by the player."""
        self.video_duration = duration
#end of class==============
def show_warning_messagebox(message):
    """Show *message* in a modal warning dialog titled "Warning" with an OK button."""
    # icon, window title, text and buttons are handed to the constructor in one go
    box = qtw.QMessageBox(
        qtw.QMessageBox.Warning,
        "Warning",
        message,
        qtw.QMessageBox.Ok,
    )
    # exec_() blocks until the user dismisses the dialog
    box.exec_()