-
Notifications
You must be signed in to change notification settings - Fork 1
/
file_management.py
100 lines (91 loc) · 2.83 KB
/
file_management.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""File for loading and saving results and models"""
import pickle
import torch
import os
import config
# Name of the sub-folder of ./results that save_model, save_results and
# save_results_AUC write into. It is read once when this module is imported,
# so later reassignments of config.output_folder are not picked up here.
folder = config.output_folder
def save_model(models, t):
    """
    Save the weights of ``config.VAD`` after a training epoch.

    The state dict is written to ``<cwd>/results/<folder>/<t + 1>_epochs.pth``;
    the directories are created if they do not exist yet.

    Args:
        models: Unused; kept so existing callers keep working. The model that
            is saved is always ``config.VAD``.
        t: Epoch counter. The file is named after ``t + 1`` (presumably ``t`` is
            the zero-based epoch index -- confirm against the training loop).
    """
    # os.path.join uses the platform separator. The previous hard-coded "\\"
    # is not a separator on POSIX, where files ended up in the parent of the
    # cwd with backslashes in their names.
    fpath = os.path.join(os.getcwd(), "results", folder)
    # Creates both levels at once. Unlike the old "except OSError: pass", real
    # failures (e.g. permission denied) are still reported.
    os.makedirs(fpath, exist_ok=True)
    fname = os.path.join(fpath, f"{t + 1}_epochs.pth")
    torch.save(config.VAD.state_dict(), fname)
def save_model_initial(models):
    """
    Save the weights of ``config.VAD`` for use as a common initialisation.

    The state dict is written to ``<cwd>/results/initial4.pth``. If the
    ``results`` directory does not exist it is created and a message is printed.

    Args:
        models: Unused; kept so existing callers keep working. The model that
            is saved is always ``config.VAD``.
    """
    # Platform-correct path instead of hard-coded Windows "\\" separators.
    res_path = os.path.join(os.getcwd(), "results")
    try:
        os.makedirs(res_path)
    except FileExistsError:
        # Already present: nothing to create or announce. Other OSErrors
        # (e.g. permission denied) now propagate instead of being swallowed.
        pass
    else:
        print(f"Successfully created the directory {res_path} ")
    fname = os.path.join(res_path, "initial4.pth")
    torch.save(config.VAD.state_dict(), fname)
def save_results(res, t):
    """
    Pickle the dictionary with training/validation information for one epoch.

    The file is written to ``<cwd>/results/<folder>/epoch_<t>.pickle``; the
    directories are created if they do not exist yet.

    Args:
        res: Object to pickle (the training/validation results dictionary).
        t: Epoch counter, used verbatim in the file name. Note that
            ``save_model`` names its checkpoint after ``t + 1`` instead --
            confirm which convention callers expect.
    """
    # Platform-correct path instead of hard-coded Windows "\\" separators.
    fpath = os.path.join(os.getcwd(), "results", folder)
    # Creates both levels at once and still reports real failures, unlike
    # the old "except OSError: pass".
    os.makedirs(fpath, exist_ok=True)
    fname = os.path.join(fpath, f"epoch_{t}.pickle")
    with open(fname, "wb") as fp:
        pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL)
def save_results_AUC(res, t):
    """
    Pickle the dictionary with testing information (AUC results).

    The file is written to
    ``<cwd>/results/<folder>/AUC_results_set<config.dset>_epoch<t>.pickle``;
    the directories are created if they do not exist yet.

    Args:
        res: Object to pickle (the test results dictionary).
        t: Epoch counter, used verbatim in the file name.
    """
    # Platform-correct path instead of hard-coded Windows "\\" separators.
    fpath = os.path.join(os.getcwd(), "results", folder)
    # Creates both levels at once and still reports real failures, unlike
    # the old "except OSError: pass".
    os.makedirs(fpath, exist_ok=True)
    # config.dset presumably identifies the evaluated data set -- confirm.
    fname = os.path.join(fpath, f"AUC_results_set{config.dset}_epoch{t}.pickle")
    with open(fname, "wb") as fp:
        pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL)
def load_results(fname=None):
    """
    Load and return a pickled results object.

    Args:
        fname: Path of the pickle file. Defaults to
            ``<cwd>/results/epoch.pickle`` (the previously hard-coded file).
            ``save_results`` writes ``results/<folder>/epoch_<t>.pickle``, so
            pass that path explicitly to read its output.

    Returns:
        The unpickled object.

    Only load files you trust: unpickling can execute arbitrary code.
    """
    if fname is None:
        # Platform-correct default instead of a Windows-only "\\" path.
        fname = os.path.join(os.getcwd(), "results", "epoch.pickle")
    with open(fname, "rb") as input_file:
        return pickle.load(input_file)
def load_model(fname=None, map_location="cpu"):
    """
    Load saved weights into ``config.VAD`` and return it.

    Args:
        fname: Path of the ``.pth`` state-dict file. Defaults to
            ``<cwd>/results/batchnorm2/20_epochs.pth``. This is the portable
            equivalent of the machine-specific absolute path that used to be
            hard-coded here, and it assumes the script runs from the project
            root (confirm).
        map_location: Forwarded to ``torch.load``. The default ``"cpu"`` lets a
            checkpoint saved on a GPU load on a CPU-only machine;
            ``load_state_dict`` then copies the tensors into ``config.VAD``'s
            existing parameters.

    Returns:
        ``config.VAD`` with the loaded weights.

    Loading uses ``strict=False``: keys missing from, or unexpected in, the file
    are ignored rather than raising an error.
    """
    if fname is None:
        fname = os.path.join(os.getcwd(), "results", "batchnorm2", "20_epochs.pth")
    state_dict = torch.load(fname, map_location=map_location)
    config.VAD.load_state_dict(state_dict, strict=False)
    return config.VAD