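# main.py: end-to-end pipeline that trains tabular generative models (VAE, CTGAN,
# TVAE), validates them with discriminator-based KL/JS divergence estimators, and
# compares synthetic pretraining, MAML/DRS meta-learning and model averaging as
# transfer-learning strategies in low-data regimes.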
import os
import sys
import torch
import pickle
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from copy import deepcopy
from tabulate import tabulate
from utils import evaluate_set
from colorama import Fore, Style
from maml_vae import main_meta_train
from joblib import Parallel, delayed
from gen_model.data import sample_cat
from divergence_estimation.utils import store_results
from gen_model.data_generation.main_generator import main as main_gen
from gen_model.data_generation.main_gen_sota import main as main_gen_sota
def train_gen_model(args):
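    """Train the selected generative model.

    Dispatches to the SOTA trainers (CTGAN / TVAE) or to the VAE trainer; in
    both cases the real data is first dumped to CSV so the trainers can reload
    it. Keys of `args` used here: 'model', 'real_df', 'output_dir', 'n_seeds',
    'n_threads'.
    """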
print(Fore.GREEN + 'Training generative model' + Style.RESET_ALL)
if args['model'] in ['ctgan', 'tvae']:
args['real_df'].to_csv(args['output_dir'] + '/real_data.csv', index=False)
args['real_df'] = None
main_gen_sota(args)
elif args['model'] in ['vae']:
for seed in range(args['n_seeds']):
log_name = os.path.join(args['output_dir'], 'seed_' + str(seed))
args['real_df'].to_csv(log_name + '_real_data.csv', index=False) # Save real data, for future use
args['real_df'] = None
        args['n_threads'] = 1  # VAE training is not parallelized: parallel runs behaved erratically on our machine; moving training to the GPU could be an alternative
main_gen(args)
else:
raise RuntimeError('Generative model not recognized')
def validation_method(n, m, l, new_seed, args, cfg):
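    """Estimate the KL and JS divergences between real and synthetic data for one seed.

    n, m, l are the sample-size parameters forwarded to `evaluate_set`; they
    also name the results subfolders, written under
    `<output_dir>/kl/{n}_{m}_{l}` and `<output_dir>/js/{n}_{m}_{l}`.
    """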
    # Create folders to store the results
    os.makedirs(os.path.join(args['output_dir'], 'kl', f'{n}_{m}_{l}'), exist_ok=True)
    os.makedirs(os.path.join(args['output_dir'], 'js', f'{n}_{m}_{l}'), exist_ok=True)
# Load data
x_real = args['data_val']
x_gen = args['syn_df']
col_names = x_real.columns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
# Dataframe to tensor
x_real = torch.tensor(x_real.values, dtype=torch.float32, device=device)
x_gen = torch.tensor(x_gen.values, dtype=torch.float32, device=device)
# This function calls the divergence estimation methods and saves the results
evaluate_set(x_real, x_gen, n, m, l, new_seed, args['output_dir'], pre_path=None, case='data', tsne_flag=False,
l_gt=None, cfg=cfg, pr=None, ps=None, dataset_name=args['dataset_name'])
def evaluate_gen_model(n, m, l, args): # Evaluate JS / KL and save results
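    """Run `validation_method` in parallel over `args['runs']` estimator seeds and store the aggregated results.

    The argparse parser below is just a convenient way to build the `cfg`
    namespace holding the discriminator's training configuration.
    """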
    seeds = list(range(args['runs']))
# Create cfg object with the configuration of the validation method
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', default=10000, type=int, help='Number of epochs to train the discriminator')
parser.add_argument('--save_model', default=True, type=bool, help='Save the model')
parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
parser.add_argument('--use_pretrained', default=False, type=bool, help='Whether to use a pretrained model')
parser.add_argument('--print_feat_js', default=False, type=bool, help='Whether to print the JS per feature')
parser.add_argument('--name', default=args['dataset_name'], type=str, help='Name of the data')
    cfg = parser.parse_args([])  # Parse an empty list: cfg only carries the defaults above and must not consume the script's own CLI arguments
# Run experiments
_ = Parallel(n_jobs=args['n_threads'], verbose=10)(delayed(validation_method)(n, m, l, seed, args, cfg)
for seed in seeds)
    # Aggregate and store the results
store_results(args['output_dir'], cfg)
def case_run(n, m, l, separate_training_evaluation, case_name, datasets, gen_methods, args, gen=True, validation=True, methodology=False, methodology_params=None, marginal_plot=True):
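    """Run one experimental case end-to-end over all datasets and generative methods.

    n is the training-set size; m, l and separate_training_evaluation are
    parallel lists, one entry per validation setting. methodology and
    methodology_params select the pretraining / fine-tuning phase (plain
    synthetic pretraining, MAML or DRS meta-learning, or model averaging),
    while gen, validation and marginal_plot toggle the corresponding stages.
    """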
for dataset in datasets:
# Prepare the data, so that it is the same for all generative models
train_data = pd.read_csv(os.path.join(args['input_dir'], 'processed_data', dataset, 'preprocessed_data_n.csv'))
metadata_file = open(os.path.join(args['input_dir'], 'processed_data', dataset, 'metadata.pkl'), 'rb')
metadata = pickle.load(metadata_file)
metadata_file.close()
cat_cols = [key for key, value in metadata['metadata'].columns.items() if value['sdtype'] == 'categorical']
        train_data = sample_cat(train_data, cat_cols, n)  # Sample while ensuring every category is represented
args['real_df'] = train_data
args['metadata'] = metadata
data_m = []
data_l = []
for j, s in enumerate(separate_training_evaluation):
if s: # Separate sets for training and evaluation
dm = pd.read_csv(os.path.join(args['input_dir'], 'processed_data', dataset, 'preprocessed_data_m.csv'))
dl = pd.read_csv(os.path.join(args['input_dir'], 'processed_data', dataset, 'preprocessed_data_l.csv'))
data_m.append(sample_cat(dm, cat_cols, m[j]))
data_l.append(sample_cat(dl, cat_cols, 2 * l[j]))
else: # The same data is used for training and evaluation
data_m.append(sample_cat(train_data, cat_cols, m[j]))
data_l.append(sample_cat(train_data, cat_cols, 2 * l[j]))
for gen_method in gen_methods:
# If we apply the methodology proposed, and we want to pretrain, reload the synthetic data here
if methodology and methodology_params['phase'] == 'pre_train_synth':
gen_data_dir = os.path.join(args['output_dir'], 'low_data', dataset, gen_method)
if gen_method == 'vae': # Load the data from the best seed only
best_seed = pd.read_csv(os.path.join(gen_data_dir, 'best_parameters.csv'))['seed'].iloc[0]
train_data = pd.read_csv(os.path.join(gen_data_dir, 'seed_' + str(best_seed) + '_gen_data.csv'))
else:
train_data = pd.read_csv(os.path.join(gen_data_dir, 'gen_data.csv'))
print('Replacing real data as generative model input with synthetic data...')
                train_data = sample_cat(train_data, cat_cols, n)  # Sample while ensuring every category is represented
args['real_df'] = train_data
elif methodology and methodology_params['phase'] == 'pre_train_synth_meta':
assert gen_method == 'vae', 'Only VAE is supported for meta-learning'
train_data = []
for seed in range(args['n_seeds']):
train_data.append(pd.read_csv(os.path.join(args['output_dir'], 'low_data', dataset, gen_method, 'seed_' + str(seed) + '_gen_data.csv')))
print('Replacing real data as generative model input with synthetic data for meta-learning...')
                train_data = [sample_cat(d, cat_cols, n) for d in train_data]  # Sample while ensuring every category is represented
args['real_df'] = train_data
elif methodology and methodology_params['phase'] == 'pre_train_synth_meta_drs':
assert gen_method == 'vae', 'Only VAE is supported for meta-learning'
train_data = []
for seed in range(args['n_seeds']):
train_data.append(pd.read_csv(os.path.join(args['output_dir'], 'low_data', dataset, gen_method, 'seed_' + str(seed) + '_gen_data.csv')))
print('Replacing real data as generative model input with synthetic data for meta-learning...')
# Concatenate all the train data
train_data = pd.concat(train_data, axis=0) # Build a new dataset using data from all VAE seeds
                train_data = sample_cat(train_data, cat_cols, n)  # Sample while ensuring every category is represented
args['real_df'] = train_data
# First, train the generative model.
args_gen = deepcopy(args)
args_gen['dataset_name'] = dataset
args_gen['model'] = gen_method
args_gen['case_name'] = case_name
args_gen['train'] = True
args_gen['eval'] = True
args_gen['generated_samples'] = 10000
args_gen['classifiers_list'] = ['MLP', 'RF']
args_gen['present_results'] = True
latent_dim = 10 # Default latent dimension
if dataset == 'news' or dataset == 'credit':
latent_dim = 20
elif dataset == 'king':
latent_dim = 15
args_gen['param_comb'] = [{'hidden_size': 256, 'latent_dim': latent_dim}] # Hyperparameters for the models
if gen_method in ['vae']: # VAE parameters
args_gen['imp_mask'] = False # We assume all data has been imputed during preprocessing
args_gen['mask_gen'] = False
args_gen['train_vae'] = True
args_gen['early_stop'] = True
args_gen['batch_size'] = 256
if methodology and methodology_params['phase'] == 'fine_tune': # If we are to fine-tune, then load the model
args_gen['use_pretrained'] = True
args_gen['pretrained_dir'] = os.path.join(args['output_dir'], case_name, dataset, gen_method, 'synthetic_pretrain')
elif methodology and methodology_params['phase'] == 'fine_tune_meta': # If we are to fine-tune, then load the model
args_gen['use_pretrained'] = True
args_gen['pretrained_dir'] = os.path.join(args['output_dir'], case_name, dataset, gen_method, 'meta_pretrain')
elif methodology and methodology_params['phase'] == 'fine_tune_avg': # If we are to fine-tune, then load the model
args_gen['use_pretrained'] = True
args_gen['pretrained_dir'] = os.path.join(args['output_dir'], case_name, dataset, gen_method, 'avg_pretrain')
else:
args_gen['use_pretrained'] = False
args_gen['pretrained_dir'] = None
if methodology and methodology_params['phase'] == 'pre_train_synth': # If we are to pretrain, then save the models in a different folder
args_gen['output_dir'] = os.path.join(args['output_dir'], case_name, dataset, gen_method, 'synthetic_pretrain')
elif methodology and (methodology_params['phase'] == 'pre_train_synth_meta' or methodology_params['phase'] == 'pre_train_synth_meta_drs'):
args_gen['output_dir'] = os.path.join(args['output_dir'], case_name, dataset, gen_method, 'meta_pretrain')
elif methodology and methodology_params['phase'] == 'pre_train_synth_avg':
args_gen['output_dir'] = os.path.join(args['output_dir'], case_name, dataset, gen_method, 'avg_pretrain')
else:
args_gen['output_dir'] = os.path.join(args['output_dir'], case_name, dataset, gen_method)
os.makedirs(args_gen['output_dir'], exist_ok=True) # Ensure that the path exists
if gen and methodology and methodology_params['phase'] == 'pre_train_synth_meta': # In this case, train the MAML meta-learner
main_meta_train(args_gen)
elif gen and methodology and methodology_params['phase'] == 'pre_train_synth_meta_drs': # In this case, train the DRS meta-learner
args_gen['n_seeds'] = 1 # Generate a single seed for the DRS pretraining
train_gen_model(args_gen) # DRS is a "standard" pretrain
elif gen and methodology and methodology_params['phase'] == 'pre_train_synth_avg': # In this case, we just compute an average model
models = [torch.load(os.path.join(args['output_dir'], 'low_data', dataset, gen_method, 'seed_' + str(seed))) for seed in range(args_gen['n_seeds'])]
new_model = deepcopy(models[0])
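                # Parameter averaging: all seeds share the same architecture, so the
                # state dicts are key-aligned and each tensor can be averaged directly.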
for key in new_model.keys():
avg_val = sum([m[key] for m in models]) / args_gen['n_seeds']
new_model[key] = avg_val
torch.save(new_model, os.path.join(args_gen['output_dir'], 'avg_pretrain'))
elif gen: # Note that the parameters may be used by the evaluation method
train_gen_model(args_gen) # Train and save results
# Prepare for evaluation (there might be several situations to validate!)
if validation:
if gen_method == 'vae': # Evaluate only the best seed
best_seed = pd.read_csv(os.path.join(args_gen['output_dir'], 'best_parameters.csv'))['seed'].iloc[0]
syn_df = pd.read_csv(os.path.join(args_gen['output_dir'], 'seed_' + str(best_seed) + '_gen_data.csv'))
else:
syn_df = pd.read_csv(os.path.join(args_gen['output_dir'], 'gen_data.csv'))
for j in range(len(separate_training_evaluation)):
args_val = deepcopy(args_gen)
args_val['n_threads'] = args['n_threads']
args_val['data_val'] = pd.concat([data_m[j], data_l[j]], axis=0)
# Now, evaluate the generative model using KL and JS divergences
args_val['output_dir'] = os.path.join(args_gen['output_dir'], 'validation')
args_val['runs'] = 5 # Number of different KL and JS estimators to use
args_val['syn_df'] = sample_cat(syn_df, cat_cols, m[j] + 2 * l[j]) # Sample the synthetic data
evaluate_gen_model(n, m[j], l[j], args_val) # Evaluate and save results
if marginal_plot: # Save the marginal histograms of each feature
if methodology and methodology_params['phase'] == 'pre_train_synth_meta':
pass # Nothing to plot in this case
else:
dfs = [args['real_df']]
names = ['real']
for gen_method in gen_methods:
if methodology and methodology_params['phase'] == 'pre_train_synth': # If we are to pretrain, then save the models in a different folder
output_dir = os.path.join(args['output_dir'], case_name, dataset, gen_method, 'synthetic_pretrain')
else:
output_dir = os.path.join(args['output_dir'], case_name, dataset, gen_method)
if gen_method == 'vae': # Evaluate only the best seed
best_seed = pd.read_csv(os.path.join(output_dir, 'best_parameters.csv'))['seed'].iloc[0]
syn_df = pd.read_csv(os.path.join(output_dir, 'seed_' + str(best_seed) + '_gen_data.csv'))
else:
syn_df = pd.read_csv(os.path.join(output_dir, 'gen_data.csv'))
dfs.append(syn_df)
names.append(gen_method)
for lab in args['real_df'].columns.tolist():
plt.hist([df[lab] for df in dfs], bins=10, label=names, density=True)
plt.title(f'Marginal distribution of {lab} in {dataset}')
plt.legend(loc='best')
plt.savefig(os.path.join(args['output_dir'], case_name, dataset, 'marginal_' + lab + '.png'))
plt.close()
if __name__ == '__main__':
## MAIN PARAMETERS OF THE CODE: CHANGE THINGS HERE
datasets = ['adult', 'news', 'king', 'intrusion']
    gen_methods = ['vae', 'ctgan']  # For the results to display correctly, the VAE must come first
args = {'cl_datasets': ['adult', 'news', 'intrusion'],
'reg_datasets': ['king'],
'n_threads': 10, # Number of threads for parallelization in the validation phase
'output_dir': os.path.join(Path(sys.argv[0]).resolve().parent, 'results'),
'input_dir': os.path.join(Path(sys.argv[0]).resolve().parent, 'data'),
'n_seeds': 10, # Number of seeds for the VAE
}
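    # Expected on-disk layout (as read by case_run below):
    #   <script_dir>/data/processed_data/<dataset>/preprocessed_data_{n,m,l}.csv
    #   <script_dir>/data/processed_data/<dataset>/metadata.pkl
    # Results are written under <script_dir>/results/<case_name>/<dataset>/<gen_method>/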
    train_methods = False  # Flag to train everything
    gen = False  # Train the generative models or the meta-learner
    validation = False  # Evaluate the generative models
    marginal_plot = False  # Store the marginal plots
show_results = True # Flag to show the results
n_large = 10000
m_large = 7500
l_large = 1000
n_low = 300
m_low = 100
l_low = 100
    if train_methods:  # Depending on the computational resources and the datasets/cases selected, this may take a long time
# CASE 1: BIG DATA (N=10000, M=7500, L=1000), separate training and evaluation set (ideal conditions)
# Note that M, L, and separate_training_evaluation are lists, so that we can run several experiments at once (they must correspond to each other)
case_run(n_large, [m_large], [l_large], [True], 'big_data', datasets, gen_methods, args, methodology=False,
gen=gen, validation=validation, marginal_plot=marginal_plot)
        # CASE 2: REALISTIC SITUATION: fewer samples, with two validations: a realistic one (few samples, no train/eval separation) and an ideal one (more samples, separated sets) to assess the divergence estimator
case_run(n_low, [m_large, m_low], [l_large, l_low], [True, False], 'low_data', datasets, gen_methods, args, methodology=False,
gen=gen, validation=validation, marginal_plot=marginal_plot)
# CASE 3: PRETRAINING + TRANSFER LEARNING
methodology_params = {
'phase': 'pre_train_synth'} # Adjusts everything for pretraining using the case_2 generated data
case_run(n_large, [m_large, m_low], [l_large, l_low], [True, False], 'pretrain', datasets, gen_methods, args,
methodology=True,
methodology_params=methodology_params, gen=gen, validation=validation, marginal_plot=marginal_plot)
methodology_params = {'phase': 'fine_tune'} # Adjusts everything to fine-tune using the pretrained model
case_run(n_low, [m_large, m_low], [l_large, l_low], [True, False], 'pretrain', datasets, gen_methods, args, methodology=True,
methodology_params=methodology_params, gen=gen, validation=validation, marginal_plot=marginal_plot)
if 'vae' in gen_methods:
# CASE 4: MAML META LEARNING + TRANSFER LEARNING (only for VAE)
methodology_params = {'phase': 'pre_train_synth_meta'} # Adjusts everything for pretraining using the case_2 generated data
case_run(n_large, [m_low], [l_low], [True], 'maml', datasets, ['vae'], args, methodology=True,
                methodology_params=methodology_params, gen=gen, validation=False, marginal_plot=False)  # Do NOT validate here (it would be meaningless); m, l and separate_training_evaluation are unused in this phase
methodology_params = {'phase': 'fine_tune_meta'} # Adjusts everything to fine-tune using the pretrained model
case_run(n_low, [m_large, m_low], [l_large, l_low], [True, False], 'maml', datasets, ['vae'], args, methodology=True,
methodology_params=methodology_params, gen=gen, validation=validation, marginal_plot=marginal_plot)
# CASE 5: DRS META LEARNING + TRANSFER LEARNING (only for VAE)
methodology_params = {'phase': 'pre_train_synth_meta_drs'} # Adjusts everything for pretraining using the case_2 generated data
case_run(n_large, [m_low], [l_low], [True], 'drs', datasets, ['vae'], args, methodology=True,
                methodology_params=methodology_params, gen=gen, validation=False, marginal_plot=False)  # Do NOT validate here (it would be meaningless); m, l and separate_training_evaluation are unused in this phase
methodology_params = {'phase': 'fine_tune_meta'} # Adjusts everything to fine-tune using the pretrained model
case_run(n_low, [m_large, m_low], [l_large, l_low], [True, False], 'drs', datasets, ['vae'], args, methodology=True,
methodology_params=methodology_params, gen=gen, validation=validation, marginal_plot=marginal_plot)
# CASE 6: MODEL AVERAGING + TRANSFER LEARNING (only for VAE)
methodology_params = {'phase': 'pre_train_synth_avg'} # Adjusts everything for pretraining using the case_2 generated data
case_run(n_large, [m_low], [l_low], [True], 'avg', datasets, ['vae'], args, methodology=True,
                methodology_params=methodology_params, gen=gen, validation=False, marginal_plot=False)  # Do NOT validate here (it would be meaningless); m, l and separate_training_evaluation are unused in this phase
methodology_params = {'phase': 'fine_tune_avg'} # Adjusts everything to fine-tune using the pretrained model
case_run(n_low, [m_large, m_low], [l_large, l_low], [True, False], 'avg', datasets, ['vae'], args, methodology=True,
methodology_params=methodology_params, gen=gen, validation=validation, marginal_plot=marginal_plot)
if show_results: # Show the results obtained
js_per_dataset_metrics = {'pretrain': [], 'maml': [], 'drs': [], 'avg': []}
kl_per_dataset_metrics = {'pretrain': [], 'maml': [], 'drs': [], 'avg': []}
vae_only_cases = ['maml', 'drs', 'avg']
for dataset in datasets:
print(f"\n Results for dataset {dataset}")
tab = []
names = ['Case', 'N', 'M', 'L'] + [f"{gen_method} JS" for gen_method in gen_methods] + [f"{gen_method} KL" for gen_method in gen_methods]
for case in ['big_data', 'low_data', 'pretrain', 'avg', 'maml', 'drs']:
dir = os.path.join(args['output_dir'], case, dataset, gen_methods[0], 'validation')
d = pd.read_csv(dir + '/js.csv') # There might be different validation methods
for i in range(len(d)):
# Get N, M and L from the first folder name
n, m, l = d['n'].iloc[i], d['m'].iloc[i], d['l'].iloc[i]
t = [case, n, m, l]
for gen_method in gen_methods:
if case in vae_only_cases and gen_method != "vae":
t.append("N/A")
else:
dir = os.path.join(args['output_dir'], case, dataset, gen_method, 'validation')
js = pd.read_csv(os.path.join(dir, 'js.csv'))['JS Discriminator'].iloc[i]
js_std = pd.read_csv(os.path.join(dir, 'js_std.csv'))['JS Discriminator'].iloc[i]
t.append(f"{js:.3f} ({js_std:.3f})")
for gen_method in gen_methods:
if case in vae_only_cases and gen_method != "vae":
t.append("N/A")
else:
dir = os.path.join(args['output_dir'], case, dataset, gen_method, 'validation')
kl = pd.read_csv(os.path.join(dir, 'kl.csv'))['KL Discriminator'].iloc[i]
kl_std = pd.read_csv(os.path.join(dir, 'kl_std.csv'))['KL Discriminator'].iloc[i]
t.append(f"{kl:.3f} ({kl_std:.3f})")
tab.append(t)
# Add the gains observed in case 3
t = [f"pretrain gain", "-", f"{m_large}", f"{l_large}"]
js_gain = []
for gen_method in gen_methods:
dir = os.path.join(args['output_dir'], 'low_data', dataset, gen_method, 'validation')
df = pd.read_csv(os.path.join(dir, 'js.csv')) # Important note: we compute gains over the case with large number of samples
js_base = df.loc[df['m'] == m_large]['JS Discriminator'].values[0]
dir = os.path.join(args['output_dir'], 'pretrain', dataset, gen_method, 'validation')
df = pd.read_csv(os.path.join(dir, 'js.csv'))
js = df.loc[df['m'] == m_large]['JS Discriminator'].values[0]
# t.append(f"{js_base - js:.3f}")
t.append(f"{(js_base - js):.3f} ({((js_base - js)/js_base):.3f})")
js_gain.append(f"{(js_base - js):.3f} ({((js_base - js)/js_base):.3f})")
js_per_dataset_metrics['pretrain'].append(js_gain)
kl_gain = []
for gen_method in gen_methods:
dir = os.path.join(args['output_dir'], 'low_data', dataset, gen_method, 'validation')
df = pd.read_csv(os.path.join(dir, 'kl.csv')) # Important note: we compute gains over the case with large number of samples
kl_base = df.loc[df['m'] == m_large]['KL Discriminator'].values[0]
dir = os.path.join(args['output_dir'], 'pretrain', dataset, gen_method, 'validation')
df = pd.read_csv(os.path.join(dir, 'kl.csv'))
kl = df.loc[df['m'] == m_large]['KL Discriminator'].values[0]
# t.append(f"{kl_base - kl:.3f}")
t.append(f"{(kl_base - kl):.3f} ({((kl_base - kl) / kl_base):.3f})")
kl_gain.append(f"{(kl_base - kl):.3f} ({((kl_base - kl) / kl_base):.3f})")
kl_per_dataset_metrics['pretrain'].append(kl_gain)
tab.append(t)
# Add the gains observed in case 4, 5 and 6: the base metrics are common to the three cases
dir = os.path.join(args['output_dir'], 'low_data', dataset, 'vae', 'validation')
df = pd.read_csv(os.path.join(dir, 'js.csv'))
js_base = df.loc[df['m'] == m_large]['JS Discriminator'].values[0]
df = pd.read_csv(os.path.join(dir, 'kl.csv'))
kl_base = df.loc[df['m'] == m_large]['KL Discriminator'].values[0]
na_string = ["N/A" for gen_method in gen_methods if gen_method != "vae"]
for case in vae_only_cases:
dir = os.path.join(args['output_dir'], case, dataset, 'vae', 'validation')
df = pd.read_csv(os.path.join(dir, 'js.csv'))
js = df.loc[df['m'] == m_large]['JS Discriminator'].values[0]
df = pd.read_csv(os.path.join(dir, 'kl.csv'))
kl = df.loc[df['m'] == m_large]['KL Discriminator'].values[0]
tab.append(
[f"{case} gain", "-", f"{m_large}", f"{l_large}", f"{js_base - js:.3f} ({((js_base - js)/js_base):.3f})"] + na_string + [f"{kl_base - kl:.3f} ({((kl_base - kl) / kl_base):.3f})"]
+ na_string)
js_per_dataset_metrics[case].append(js_base - js)
kl_per_dataset_metrics[case].append(kl_base - kl)
print(tabulate(tab, headers=names, tablefmt='orgtbl'))
# print(tabulate(tab, headers=names, tablefmt='latex'))
print('Average GAIN in pretrain')
for i, gen_method in enumerate(gen_methods):
print(f"JS {gen_method}: {np.mean([float((d[i]).split(' ')[0]) for d in js_per_dataset_metrics['pretrain']])}")
print(f"KL {gen_method}: {np.mean([float((d[i]).split(' ')[0]) for d in kl_per_dataset_metrics['pretrain']])}")
for case in vae_only_cases:
print(f"JS Average GAIN in case {case} for VAE: {np.mean(js_per_dataset_metrics[case])}")
print(f"KL Average GAIN in case {case} for VAE: {np.mean(kl_per_dataset_metrics[case])}")