-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
83 lines (64 loc) · 2.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import time
from datetime import datetime

from CounterModels.DreaMR.run_counter_dreamr import train_counter_dreamr, gen_counter_dreamr, eval_counter_dreamr
from Classifiers.BolT.run_classifier_bolt import run_bolT
from Utils.gpuChecker import getAvailableGpus
from Utils.utils import Option
# Command-line interface. Every option is forwarded into the `details` Option
# (or used for GPU selection / stage dispatch) further down in this script.
parser = argparse.ArgumentParser()
parser.add_argument("--targetDataset", type=str, default="dummy",
                    help="dataset name forwarded to the selected method as details.targetDataset")
parser.add_argument("--method", type=str, default="dreamr",
                    help="key into the trainer/generator/evaluator dispatch tables")
# Parsed as a float (previously type=str and converted later) so an invalid
# value is rejected here with a usage message instead of crashing later.
parser.add_argument("--loadThreshold", type=float, default=0.5,
                    help="load threshold passed to getAvailableGpus when choosing a GPU")
parser.add_argument("--do", type=str, default="train",
                    help="stages to run; matched by substring, so e.g. 'train_gen_eval' runs all three")
parser.add_argument("--fromExists", type=int, default=0,
                    help="integer flag forwarded as details.fromExists")
parser.add_argument("--isVal", type=int, default=0,
                    help="integer flag forwarded as details.isVal")
argv = parser.parse_args()
# Wait until getAvailableGpus reports at least one usable GPU for the given
# load threshold, then run on the first one it returns.
GPU_POLL_INTERVAL_SEC = 5  # pause between polls so the wait does not busy-spin a CPU core
availableGpus = []
_announcedWait = False
while not availableGpus:
    availableGpus = getAvailableGpus(float(argv.loadThreshold))
    if not availableGpus:
        if not _announcedWait:
            print("No GPU available under load threshold {}; waiting...".format(argv.loadThreshold))
            _announcedWait = True
        time.sleep(GPU_POLL_INTERVAL_SEC)
gpu = availableGpus[0]  # take the first GPU reported as available
device = "cuda:{}".format(gpu)
# Run-wide configuration shared by every stage.
foldCount = 5  # marked "ignore this" by the author; still forwarded in details
nOfClasses = 2
dynamicLength = 128
# Timestamp for this run. Note: the ':' characters are not valid in Windows paths.
datePrepend = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')

# User-editable paths. Replace the placeholders below before running.
targetClassifierPath = "your/target/classifier/path/model.pt"
targetRunFolder = None  # e.g. "your/target/run/folder"
targetGenFolders = None  # e.g. ["your/target/gen/folder1", "your/target/gen/folder2"]

# Bundle everything into the Option object passed to the selected stage function.
# The path fields reference the variables above (previously hard-coded to None,
# which silently ignored any value the user set there).
details = Option({
    "device": device,
    "foldCount": foldCount,
    "datePrepend": datePrepend,
    "targetDataset": argv.targetDataset,
    "classifierPath": targetClassifierPath,
    "fromExists": argv.fromExists,
    "nOfClasses": nOfClasses,
    "dynamicLength": dynamicLength,
    "targetRunFolder": targetRunFolder,
    "targetGenFolders": targetGenFolders,
    "methodName": argv.method,
    "isVal": argv.isVal,
})
# Dispatch tables: each maps a --method value to the function that runs that
# stage. Every function is called with the shared `details` Option.
trainers = {
    # classifier trainers
    "bolT_classify": run_bolT,
    # counterfactual-model trainers
    "dreamr": train_counter_dreamr,
}
# "dreamr" is the key used by `trainers` and the --method default, so it is
# registered for generation/evaluation as well (previously only "defacto" was,
# making `--do gen`/`--do eval` fail with the default method). "defacto" is
# kept as an alias so existing invocations keep working.
generators = {
    "dreamr": gen_counter_dreamr,
    "defacto": gen_counter_dreamr,
}
evaluaters = {
    "dreamr": eval_counter_dreamr,
    "defacto": eval_counter_dreamr,
}
def _resolve_stage(table, stage):
    """Return the function registered for ``argv.method`` in *table*.

    Exits through ``parser.error`` with the list of valid methods when the
    method has no entry for this stage, instead of raising a bare KeyError.
    """
    try:
        return table[argv.method]
    except KeyError:
        parser.error("--method {!r} has no {} function; choose one of: {}".format(
            argv.method, stage, ", ".join(sorted(table))))


# Stages are selected by substring, so a single --do value can chain them
# (e.g. "train_gen_eval"). They always run in train -> gen -> eval order.
_stageTables = [
    ("train", trainers),
    ("gen", generators),
    ("eval", evaluaters),
]
# Resolve every requested stage before running any, so an unsupported method
# fails immediately rather than after a long training run.
_plan = [(stage, _resolve_stage(table, stage))
         for stage, table in _stageTables if stage in argv.do]
if not _plan:
    parser.error("--do {!r} selects no stage; include at least one of: train, gen, eval".format(argv.do))
for _stage, _runStage in _plan:
    _runStage(details)