-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
28 lines (24 loc) · 997 Bytes
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import matplotlib.pyplot as plt
def parse_accuracy_data(file_path):
    """Return the accuracy values found in a training log, as fractions.

    Every line containing the substring "Accuracy" is treated as an
    accuracy record. The value is taken from the last comma-separated
    field of that line, specifically its last whitespace-separated token,
    with any '%' removed. That number is divided by 100, so "95.00%"
    becomes 0.95. Lines without "Accuracy" are ignored.

    NOTE(review): the division by 100 is applied whether or not a '%' was
    present; this assumes the log always reports percentages — confirm
    against the training script that writes these logs.

    Args:
        file_path: Path of the text log to read.

    Returns:
        A list of floats, one per accuracy line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If an "Accuracy" line does not end in a number.
    """
    accuracies = []
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if "Accuracy" not in line:
                continue
            last_field = line.split(',')[-1]
            # split() with no argument tolerates tabs / repeated whitespace
            # between the label and the value.
            accuracy_str = last_field.split()[-1].replace('%', '')
            accuracies.append(float(accuracy_str) / 100.0)
    return accuracies
def plot_accuracies(*accuracy_lists, labels=('withoutPruning', 'withPruning', 'worstPruning')):
    """Plot one accuracy curve per series on a single figure and show it.

    Each series is drawn against its index (0, 1, 2, ...) on the x-axis,
    which is labelled "Epoch".

    Args:
        *accuracy_lists: Sequences of accuracy values, one per curve.
        labels: Legend names matched to the series by position. It defaults
            to the three pruning-experiment names used previously. Series
            beyond the supplied labels are drawn but left out of the legend,
            so the legend always matches the lines it describes.
    """
    for index, accuracies in enumerate(accuracy_lists):
        # Attach the label to the line itself, rather than passing a fixed
        # list to legend(), so line/label pairing cannot drift.
        label = labels[index] if index < len(labels) else None
        plt.plot(accuracies, label=label)
    plt.title('Epoch-wise Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()
# Training-log files, one per experiment, in the same order as the
# legend labels used when plotting.
file_paths = [
    'training_logwithoutPruning.txt',
    'training_logwithPruning.txt',
    'training_logworstPruning.txt',
]
# Extract the accuracy series from each log, then draw them together.
accuracy_data_sets = list(map(parse_accuracy_data, file_paths))
plot_accuracies(*accuracy_data_sets)