-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
69 lines (65 loc) · 3.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from config import data_info, algos_path, default_it_max, norm_type
from utils import run_gd, run_gd_plus, run_gd_plus_fnl, run_diana, run_diana_plus, run_diana_vnl, run_diana_plus_fnl, \
run_gd_vnl, run_sparse_diana_plus, run_block_diana, run_block_diana_plus, run_block_dcgd, run_block_dcgd_plus, \
run_block_diana_plus_fnl, run_diag_diana_plus, run_diag_block_diana_plus, run_diag_block_diana_plus_safe
import argparse
# Command-line entry point: parse --data / --alg / --n, look the dataset up in
# data_info, and hand off to the runner from utils that matches --alg.

# Algorithm name -> runner that is called as
# runner(algos_path, n_workers, it_max, norm_type, dataset).
# NOTE(review): run_diana_vnl is imported at the top of the file but no
# algorithm name dispatches to it.
_RUNNERS_WITH_NORM = {
    'SD-DCGD': run_gd,
    'SD-DCGD-plus': run_gd_plus,
    'SD-DCGD-plus-fnl': run_gd_plus_fnl,
    'SD-DCGD-vnl': run_gd_vnl,
    'SD-DIANA': run_diana,
    'SD-DIANA-plus': run_diana_plus,
    'SD-DIANA-plus-fnl': run_diana_plus_fnl,
    'BL-DIANA': run_block_diana,
    'BL-DCGD': run_block_dcgd,
    'BL-DCGD-plus': run_block_dcgd_plus,
    'BL-DIANA-plus': run_block_diana_plus,
    'BL-DIANA-plus-fnl': run_block_diana_plus_fnl,
    'diag-SD-DIANA-plus': run_diag_diana_plus,
    # NOTE(review): an "SD" name mapped to a "block" runner; kept exactly as
    # in the original dispatch -- confirm this pairing is intended.
    'diag-SD-DIANA-plus-safe': run_diag_block_diana_plus_safe,
    'diag-BL-DIANA-plus': run_diag_block_diana_plus,
}

# Algorithm name -> runner that is called without norm_type, i.e. as
# runner(algos_path, n_workers, it_max, dataset).
_RUNNERS_WITHOUT_NORM = {
    'Sparse-DIANA-plus': run_sparse_diana_plus,
}

_ALGO_NAMES = list(_RUNNERS_WITH_NORM) + list(_RUNNERS_WITHOUT_NORM)

parser = argparse.ArgumentParser(description='Run the Algorithms')
# --data and --alg are required: when omitted they default to None, which
# previously surfaced as a data_info[None] lookup or as the generic
# "not implemented" ValueError instead of a message naming the missing flag.
parser.add_argument('--data', action='store', dest='dataset', type=str,
                    required=True, help='Dataset')
parser.add_argument('--alg', action='store', dest='algo_name', type=str,
                    required=True,
                    help='Which algorithm: ' + ', '.join(_ALGO_NAMES))
# --n stays optional as before; when omitted, n_workers is None and is passed
# to the runner unchanged.
parser.add_argument('--n', action='store', dest='n_workers', type=int,
                    help='Number of workers')
args = parser.parse_args()

dataset = args.dataset
algo_name = args.algo_name
n_workers = args.n_workers

# N and d are not used anywhere below; the lookup is kept so that a dataset
# missing from data_info fails here, before any runner is started.
N, d = data_info[dataset]

if algo_name in _RUNNERS_WITH_NORM:
    it_max = int(default_it_max)
    _RUNNERS_WITH_NORM[algo_name](algos_path, n_workers, it_max, norm_type, dataset)
elif algo_name in _RUNNERS_WITHOUT_NORM:
    it_max = int(default_it_max)
    _RUNNERS_WITHOUT_NORM[algo_name](algos_path, n_workers, it_max, dataset)
else:
    raise ValueError('The algorithm has not been implemented!')