'''
The results of PREDICTD are improved by averaging the imputed data from
multiple models to produce a consensus imputed data track. This script will
loop through a list of validation sets and train a model for each one so that
the imputation results from multiple models can be averaged for a particular
test set.
'''
import argparse
import copy
import numpy
import os
import subprocess
import sys
sys.path.append(os.path.dirname(__file__))
import s3_library
#import azure_library
#import impute_roadmap_consolidated_data as spark_model
import train_model as spark_model
spark_model.pl.s3_library = s3_library
#spark_model.pl.azure_library = azure_library
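
# Example invocation (illustrative sketch only; the full set of flags comes
# from train_model's argument parser, and flag names other than --num_folds,
# --valid_fold_idx, and --factor_init_seed are assumed here rather than
# confirmed by this file):
#
#   spark-submit train_bags.py \
#       --data_url=s3://<data-bucket>/<dataset> \
#       --run_bucket=<run-bucket> \
#       --out_root=<output-prefix> \
#       --num_folds=2 \
#       --factor_init_seed=42
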
if __name__ == "__main__":
    parser = spark_model.parser
    parser.add_argument('--num_folds', type=int, default=2,
                        help='The number of validation folds for '
                        'which a model should be trained. This '
                        'should be between 2 and the number of '
                        'available validation folds, inclusive. '
                        '[default: %(default)s]')
    args = parser.parse_args()
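
    # Record the exact command line for this run in the run bucket and select
    # the storage backend from the data URL scheme; only S3 is currently
    # active (the Azure blob branch is commented out), so other schemes are
    # rejected up front.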
    out_root = args.out_root
    num_folds = args.num_folds
    rs = numpy.random.RandomState(seed=args.factor_init_seed)
    cmd_line_path = os.path.join(out_root, 'command_line.txt')
    cmd_line_txt = ' '.join(sys.argv) + '\n'
    if args.data_url.startswith('s3'):
        bucket = s3_library.S3.get_bucket(args.run_bucket)
        key = bucket.new_key(cmd_line_path)
        key.set_contents_from_string(cmd_line_txt)
        spark_model.pl.STORAGE = 'S3'
#    elif args.data_url.startswith('wasb'):
#        azure_library.load_blob_from_text(args.run_bucket, cmd_line_path, cmd_line_txt)
#        spark_model.pl.STORAGE = 'BLOB'
    else:
        raise Exception('Unrecognized URL prefix on data url: {!s}'.format(args.data_url))
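
    # If this is a restart, infer how many validation folds have already
    # finished by counting existing per-fold ct_factors.pickle outputs and
    # resume the loop from that index; otherwise start from fold 0.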
    active_fold_path = os.path.join(out_root, 'active_fold_idx.pickle')
    try:
        if spark_model.pl.STORAGE == 'S3':
#            range_start = s3_library.get_pickle_s3(args.run_bucket, active_fold_path)
##        elif spark_model.pl.STORAGE == 'BLOB':
##            range_start = azure_library.get_blob_pickle(args.run_bucket, active_fold_path)
            range_start = len(s3_library.glob_keys(args.run_bucket,
                                                   os.path.join(out_root, 'valid_fold*/ct_factors.pickle')))
    except:
        range_start = 0
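
    # Train one model per validation fold. Each iteration gets a fresh
    # SparkContext, picks a validation fold index whose output directory has
    # no command_line.txt yet, draws a new factor initialization seed, records
    # the per-fold command line, and runs training for that fold.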
    rs2 = numpy.random.RandomState(seed=args.factor_init_seed)
    for iter_idx in range(range_start, num_folds):
        sc = spark_model.SparkContext(appName='avg_valid_folds',
                                      pyFiles=[os.path.join(os.path.dirname(__file__), 's3_library.py'),
#                                              os.path.join(os.path.dirname(__file__), 'impute_roadmap_consolidated_data.py'),
                                               os.path.join(os.path.dirname(__file__), 'train_model.py'),
                                               os.path.join(os.path.dirname(__file__), 'predictd_lib.py')])
        spark_model.pl.sc = sc
        if spark_model.pl.STORAGE == 'S3':
            s3_library.set_pickle_s3(args.run_bucket, active_fold_path, iter_idx)
#        elif spark_model.pl.STORAGE == 'BLOB':
#            azure_library.load_blob_pickle(args.run_bucket, active_fold_path, iter_idx)
        idx_args = copy.deepcopy(args)
        valid_cmd_line_path = ''
        while not valid_cmd_line_path or s3_library.S3.get_bucket(args.run_bucket).get_key(valid_cmd_line_path):
            valid_idx = rs2.randint(8)
            idx_args.out_root = os.path.join(out_root, 'valid_fold{!s}'.format(valid_idx))
            valid_cmd_line_path = os.path.join(idx_args.out_root, 'command_line.txt')
        valid_cmd_line_txt = cmd_line_txt.strip().replace(
            os.path.basename(__file__).replace('.pyc', '.py'),
            os.path.basename(spark_model.__file__).replace('.pyc', '.py')).replace(out_root, idx_args.out_root)
        idx_args.valid_fold_idx = valid_idx
        valid_cmd_line_txt += ' --valid_fold_idx={!s}'.format(idx_args.valid_fold_idx)
        idx_args.factor_init_seed = rs.randint(int(1e6))
        valid_cmd_line_txt += ' --factor_init_seed={!s}'.format(idx_args.factor_init_seed)
#        idx_args.data_iteration_seed = rs.randint(int(1e6))
#        valid_cmd_line_txt += ' --data_iteration_seed={!s}'.format(idx_args.data_iteration_seed)
#        idx_args.random_loci_fraction_seed = rs.randint(int(1e6))
#        valid_cmd_line_txt += ' --random_loci_fraction_seed={!s}'.format(idx_args.random_loci_fraction_seed)
#        idx_args.train_on_subset_seed = rs.randint(int(1e6))
#        valid_cmd_line_txt += ' --train_on_subset_seed={!s}'.format(idx_args.train_on_subset_seed)
        if iter_idx < range_start:
            continue
        if idx_args.data_url.startswith('s3'):
            bucket = s3_library.S3.get_bucket(idx_args.run_bucket)
            key = bucket.new_key(valid_cmd_line_path)
            key.set_contents_from_string(valid_cmd_line_txt)
#        elif idx_args.data_url.startswith('wasb'):
#            azure_library.load_blob_from_text(idx_args.run_bucket, valid_cmd_line_path, valid_cmd_line_txt)
        else:
            raise Exception('Unrecognized URL prefix on data url: {!s}'.format(args.data_url))
        #if this is a restart, make sure we don't just restart again in the next iteration
        if iter_idx == range_start and args.restart is True:
            args.restart = False
            args.checkpoint = None
#        #don't use the same pctl_res for the next fold because the training set changes.
#        if args.pctl_res is not None:
#            args.pctl_res = None
        spark_model.train_consolidated(idx_args)
        sc.stop()
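
    # Once all folds have trained, collect each fold's output directory
    # (identified by its hyperparameters.pickle) and check that one result
    # exists per requested fold before averaging.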
    imp_result_glob = os.path.join(out_root, '*/hyperparameters.pickle')
    if spark_model.pl.STORAGE == 'S3':
        glob_result1 = s3_library.glob_keys(args.run_bucket, imp_result_glob)
#        glob_result2 = s3_library.glob_keys(args.run_bucket, os.path.join(os.path.dirname(imp_result_glob), 'num_parts.pickle'))
#    elif spark_model.pl.STORAGE == 'BLOB':
#        glob_result1 = azure_library.glob_blobs(args.run_bucket, imp_result_glob)
#        glob_result2 = azure_library.glob_blobs(args.run_bucket, os.path.join(os.path.dirname(imp_result_glob), 'num_parts.pickle'))
    imp_result_paths = ['s3://{!s}/{!s}'.format(args.run_bucket, key_path)
                        for key_path in set([os.path.dirname(elt.name) for elt in glob_result1])]
#    num_folds = 4
    if len(imp_result_paths) != num_folds:
        raise Exception('Only {!s} of {!s} validation folds calculated.'.format(len(imp_result_paths), num_folds))
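
    # Start a final SparkContext, average the imputed test-set values across
    # the per-fold models into a consensus track, and save it as a pickled RDD
    # under the output root.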
    sc = spark_model.SparkContext(appName='avg_valid_folds',
                                  pyFiles=[os.path.join(os.path.dirname(__file__), 's3_library.py'),
#                                          os.path.join(os.path.dirname(__file__), 'azure_library.py'),
#                                          os.path.join(os.path.dirname(__file__), 'impute_roadmap_consolidated_data.py'),
                                           os.path.join(os.path.dirname(__file__), 'train_model.py'),
                                           os.path.join(os.path.dirname(__file__), 'predictd_lib.py')])
    spark_model.pl.sc = sc
    if spark_model.pl.STORAGE == 'S3':
#        storage_url_fmt = 's3n://{!s}/{!s}'
        storage_url_fmt = 's3://{!s}/{!s}'
#    elif spark_model.pl.STORAGE == 'BLOB':
#        storage_url_fmt = 'wasbs://{!s}@imputationstoretim.blob.core.windows.net/{!s}'
#    to_join = [sc.pickleFile(storage_url_fmt.format(args.run_bucket, elt)) for elt in imp_result_paths]
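    # Conceptually, the consensus is an element-wise average of the imputed
    # values from each fold's model. A rough illustrative sketch (the names
    # `fold_models` and `model.impute` are hypothetical; the real averaging
    # runs on Spark RDDs inside predictd_lib.impute_and_avg):
    #
    #   imputed = numpy.stack([model.impute(test_coords) for model in fold_models])
    #   consensus = imputed.mean(axis=0)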
    avg_imp = spark_model.pl.impute_and_avg(imp_result_paths, coords='test').persist()
    avg_imp.count()
    out_url = storage_url_fmt.format(args.run_bucket,
                                     os.path.join(out_root, '3D_svd_imputed.avg.test_set.rdd.pickle'))
    spark_model.pl.save_rdd_as_pickle(avg_imp, out_url)
    sc.stop()
    #transform averaged models
#    cmd = ['spark-submit', os.path.join(os.path.dirname(__file__), 'transform_imputed_data.py'),
#           '--data_url={!s}'.format(args.data_url), '--imputed_rdd={!s}'.format(out_url),
#           '--fold_idx={!s}'.format(args.fold_idx), '--valid_fold_idx=-1', '--num_percentiles=100']
#    subprocess.check_call(cmd)