diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..c3f17c4 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index b6e4761..5f8d199 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ __pycache__/ *.so # Distribution / packaging +data/ +.idea/ .Python build/ develop-eggs/ diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..73f69e0 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/.idea/DU-VAE.iml b/.idea/DU-VAE.iml new file mode 100644 index 0000000..8b8c395 --- /dev/null +++ b/.idea/DU-VAE.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..cef7982 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,68 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..3999087 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..03505bc --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100755 
index 0000000..d456b34 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Junxian He + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/config/.DS_Store b/config/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/config/.DS_Store differ diff --git a/config/config_omniglot.py b/config/config_omniglot.py new file mode 100755 index 0000000..6baf81f --- /dev/null +++ b/config/config_omniglot.py @@ -0,0 +1,12 @@ +params={ + 'img_size': [1,28,28], + 'nz': 32, + 'enc_layers': [64, 64, 64], + 'dec_kernel_size': [7, 7, 7, 7, 7, 5, 5, 5, 5, 3, 3, 3, 3], + 'dec_layers': [32,32,32,32,32,32,32,32,32,32,32,32], + 'latent_feature_map': 4, + 'batch_size': 50, + 'epochs': 500, + 'test_nepoch': 5, + 'data_file': 'data/omniglot_data/omniglot.pt' +} diff --git a/config/config_omniglot_ss.py b/config/config_omniglot_ss.py new file mode 100644 index 0000000..cdc448e --- /dev/null +++ b/config/config_omniglot_ss.py @@ -0,0 +1,7 @@ +params={ + 'img_size': [1,28,28], + 'nz': 32, + 'latent_feature_map': 4, + 'test_nepoch': 5, + 'root': 'data/omniglot_data/' +} \ No newline at end of file diff --git a/config/config_short_yelp.py b/config/config_short_yelp.py new file mode 100755 index 0000000..236fbd6 --- /dev/null +++ b/config/config_short_yelp.py @@ -0,0 +1,155 @@ +params={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 128, + 'enc_nh': 512, + 'dec_nh': 512, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + 'batch_size': 32, + 'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/short_yelp_data/short_yelp.train.txt', + 'val_data': 'data/short_yelp_data/short_yelp.valid.txt', + 'test_data': 'data/short_yelp_data/short_yelp.test.txt', + 'vocab_file': 'data/short_yelp_data/vocab.txt', + "label": True +} + + +params_ss_10={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 128, + 'enc_nh': 512, + 'dec_nh': 512, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/short_yelp_data/short_yelp.train.10.txt', + 'val_data': 
'data/short_yelp_data/short_yelp.valid.txt', + 'test_data': 'data/short_yelp_data/short_yelp.test.txt', + # 'vocab_file': 'data/short_yelp_data/vocab.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 10, + "label": True +} + + +params_ss_100={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 128, + 'enc_nh': 512, + 'dec_nh': 512, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/short_yelp_data/short_yelp.train.100.txt', + 'val_data': 'data/short_yelp_data/short_yelp.valid.txt', + 'test_data': 'data/short_yelp_data/short_yelp.test.txt', + # 'vocab_file': 'data/short_yelp_data/vocab.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 10, + "label": True +} + +params_ss_500={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 128, + 'enc_nh': 512, + 'dec_nh': 512, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + #'batch_size': 50, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/short_yelp_data/short_yelp.train.500.txt', + 'val_data': 'data/short_yelp_data/short_yelp.valid.txt', + 'test_data': 'data/short_yelp_data/short_yelp.test.txt', + # 'vocab_file': 'data/short_yelp_data/vocab.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 10, + "label": True +} + +params_ss_1000={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 128, + 'enc_nh': 512, + 'dec_nh': 512, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/short_yelp_data/short_yelp.train.1000.txt', + 'val_data': 'data/short_yelp_data/short_yelp.valid.txt', + 'test_data': 'data/short_yelp_data/short_yelp.test.txt', + # 'vocab_file': 'data/short_yelp_data/vocab.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 10, + "label": True +} + + +params_ss_2000={ + 'enc_type': 'lstm', + 
'dec_type': 'lstm', + 'nz': 32, + 'ni': 128, + 'enc_nh': 512, + 'dec_nh': 512, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/short_yelp_data/short_yelp.train.2000.txt', + 'val_data': 'data/short_yelp_data/short_yelp.valid.txt', + 'test_data': 'data/short_yelp_data/short_yelp.test.txt', + # 'vocab_file': 'data/short_yelp_data/vocab.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 10, + "label": True +} + + +params_ss_10000={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 128, + 'enc_nh': 512, + 'dec_nh': 512, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/short_yelp_data/short_yelp.train.10000.txt', + 'val_data': 'data/short_yelp_data/short_yelp.valid.txt', + 'test_data': 'data/short_yelp_data/short_yelp.test.txt', + # 'vocab_file': 'data/short_yelp_data/vocab.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 10, + "label": True +} diff --git a/config/config_synthetic.py b/config/config_synthetic.py new file mode 100755 index 0000000..31649bf --- /dev/null +++ b/config/config_synthetic.py @@ -0,0 +1,19 @@ + +params={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 2, + 'ni': 50, + 'enc_nh': 50, + 'dec_nh': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + 'epochs': 50, + 'batch_size': 32, + 'test_nepoch': 1, + 'train_data': 'data/synthetic_data/synthetic_train.txt', + 'val_data': 'data/synthetic_data/synthetic_test.txt', + 'test_data': 'data/synthetic_data/synthetic_test.txt', + 'vocab_file': 'data/synthetic_data/vocab.txt', + "label": True +} diff --git a/config/config_yahoo.py b/config/config_yahoo.py new file mode 100755 index 0000000..48c8f0b --- /dev/null +++ b/config/config_yahoo.py @@ -0,0 +1,18 @@ + +params={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, 
+ 'dec_nh': 1024, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + 'batch_size': 32, + 'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/yahoo_data/yahoo.train.txt', + 'val_data': 'data/yahoo_data/yahoo.valid.txt', + 'test_data': 'data/yahoo_data/yahoo.test.txt', + 'vocab_file': 'data/yahoo_data/vocab.txt' +} diff --git a/config/config_yelp.py b/config/config_yelp.py new file mode 100755 index 0000000..f9d0d64 --- /dev/null +++ b/config/config_yelp.py @@ -0,0 +1,149 @@ + +params={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, + 'dec_nh': 1024, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + 'batch_size': 32, + 'epochs': 150, + 'test_nepoch': 5, + 'train_data': 'data/yelp_data/yelp.train.txt', + 'val_data': 'data/yelp_data/yelp.valid.txt', + 'test_data': 'data/yelp_data/yelp.test.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'label':True +} + + +params_ss_10={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, + 'dec_nh': 1024, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/yelp_data/yelp.train.10.txt', + 'val_data': 'data/yelp_data/yelp.valid.txt', + 'test_data': 'data/yelp_data/yelp.test.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 5, + "label": True +} + + +params_ss_100={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, + 'dec_nh': 1024, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/yelp_data/yelp.train.100.txt', + 'val_data': 'data/yelp_data/yelp.valid.txt', + 'test_data': 'data/yelp_data/yelp.test.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 5, + "label": True +} + +params_ss_500={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, + 'dec_nh': 
1024, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + #'batch_size': 50, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/yelp_data/yelp.train.500.txt', + 'val_data': 'data/yelp_data/yelp.valid.txt', + 'test_data': 'data/yelp_data/yelp.test.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 5, + "label": True +} + +params_ss_1000={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, + 'dec_nh': 1024, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/yelp_data/yelp.train.1000.txt', + 'val_data': 'data/yelp_data/yelp.valid.txt', + 'test_data': 'data/yelp_data/yelp.test.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 5, + "label": True +} + + +params_ss_2000={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, + 'dec_nh': 1024, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/yelp_data/yelp.train.2000.txt', + 'val_data': 'data/yelp_data/yelp.valid.txt', + 'test_data': 'data/yelp_data/yelp.test.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 5, + "label": True +} + + +params_ss_10000={ + 'enc_type': 'lstm', + 'dec_type': 'lstm', + 'nz': 32, + 'ni': 512, + 'enc_nh': 1024, + 'dec_nh': 1024, + 'log_niter': 50, + 'dec_dropout_in': 0.5, + 'dec_dropout_out': 0.5, + # 'batch_size': 32, + #'epochs': 100, + 'test_nepoch': 5, + 'train_data': 'data/yelp_data/yelp.train.10000.txt', + 'val_data': 'data/yelp_data/yelp.valid.txt', + 'test_data': 'data/yelp_data/yelp.test.txt', + 'vocab_file': 'data/yelp_data/vocab.txt', + 'ncluster': 5, + "label": True +} \ No newline at end of file diff --git a/exp_utils.py b/exp_utils.py new file mode 100755 index 0000000..6f55642 --- /dev/null +++ b/exp_utils.py @@ -0,0 +1,43 @@ +import functools 
+import os, shutil + +import numpy as np + +import torch + + +def logging(s, log_path, print_=True, log_=True): + print(s) + # if print_: + # print(s) + if log_: + with open(log_path, 'a+') as f_log: + f_log.write(s + '\n') + +def get_logger(log_path, **kwargs): + return functools.partial(logging, log_path=log_path, **kwargs) + +def create_exp_dir(dir_path, scripts_to_save=None, debug=False): + if debug: + print('Debug Mode : no experiment dir created') + return functools.partial(logging, log_path=None, log_=False) + + if os.path.exists(dir_path): + print("Path {} exists. Remove and remake.".format(dir_path)) + shutil.rmtree(dir_path) + + os.makedirs(dir_path) + + print('Experiment dir : {}'.format(dir_path)) + if scripts_to_save is not None: + script_path = os.path.join(dir_path, 'scripts') + if not os.path.exists(script_path): + os.makedirs(script_path) + for script in scripts_to_save: + dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + shutil.copyfile(script, dst_file) + + return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + +def save_checkpoint(model, optimizer, path, epoch): + torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) \ No newline at end of file diff --git a/image.py b/image.py new file mode 100755 index 0000000..0b20da2 --- /dev/null +++ b/image.py @@ -0,0 +1,406 @@ +import sys +import os +import time +import importlib +import argparse + +import numpy as np + +import torch +import torch.utils.data +from torch import optim + +from modules import ResNetEncoderV2, BNResNetEncoderV2, PixelCNNDecoderV2 +from modules import VAE +from logger import Logger +from utils import calc_mi + +clip_grad = 5.0 +decay_epoch = 20 +lr_decay = 0.5 +max_decay = 5 + + +def init_config(): + parser = argparse.ArgumentParser(description='VAE mode collapse study') + + # model hyperparameters + parser.add_argument('--dataset', default='omniglot', type=str, help='dataset to use') + + # optimization parameters + 
parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training') + parser.add_argument('--iw_nsamples', type=int, default=500, + help='number of samples to compute importance weighted estimate') + # select mode + parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') + parser.add_argument('--load_path', type=str, default='') + # annealing paramters + parser.add_argument('--warm_up', type=int, default=10) + parser.add_argument('--kl_start', type=float, default=1.0) + + # these are for slurm purpose to save model + parser.add_argument('--jobid', type=int, default=0, help='slurm job id') + parser.add_argument('--taskid', type=int, default=0, help='slurm task id') + parser.add_argument('--device', type=str, default="cpu") + parser.add_argument('--delta_rate', type=float, default=1.0, + help=" coontrol the minization of the variation of latent variables") + parser.add_argument('--gamma', type=float, default=0.5) # BN-VAE + parser.add_argument("--reset_dec", action="store_true", default=False) + parser.add_argument("--nz_new", type=int, default=32) # myGaussianLSTMencoder + parser.add_argument('--p_drop', type=float, default=0.2) # p \in [0, 1] + + args = parser.parse_args() + if 'cuda' in args.device: + args.cuda = True + else: + args.cuda = False + + load_str = "_load" if args.load_path != "" else "" + save_dir = "models/%s%s/" % (args.dataset, load_str) + + + if args.warm_up > 0 and args.kl_start < 1.0: + cw_str = '_warm%d' % args.warm_up + else: + cw_str = '' + + hkl_str = 'KL%.2f' % args.kl_start + drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else '' + + seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909] + args.seed = seed_set[args.taskid] + + if args.gamma > 0: + gamma_str = '_gamma%.2f' % (args.gamma) + else: + gamma_str = '' + + id_ = "%s_%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d" % \ + (args.dataset, hkl_str, + cw_str, load_str, gamma_str, args.delta_rate, + args.nz_new,drop_str, + 
args.jobid, args.taskid, args.seed) + + save_dir += id_ + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + save_path = os.path.join(save_dir, 'model.pt') + + args.save_path = save_path + print("save path", args.save_path) + + args.log_path = os.path.join(save_dir, "log.txt") + print("log path", args.log_path) + + # load config file into args + config_file = "config.config_%s" % args.dataset + params = importlib.import_module(config_file).params + args = argparse.Namespace(**vars(args), **params) + if args.nz != args.nz_new: + args.nz = args.nz_new + print('args.nz', args.nz) + + if 'label' in params: + args.label = params['label'] + else: + args.label = False + + args.kl_weight = 1 + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.cuda: + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + return args + + +def test(model, test_loader, mode, args): + report_kl_loss = report_kl_t_loss = report_rec_loss = 0 + report_num_examples = 0 + mutual_info = [] + for datum in test_loader: + batch_data, _ = datum + batch_data = batch_data.to(args.device) + + batch_size = batch_data.size(0) + + report_num_examples += batch_size + + loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False) + loss_kl_t = model.KL(batch_data, args) + + assert (not loss_rc.requires_grad) + + loss_rc = loss_rc.sum() + loss_kl = loss_kl.sum() + loss_kl_t = loss_kl_t.sum() + + report_rec_loss += loss_rc.item() + report_kl_loss += loss_kl.item() + report_kl_t_loss += loss_kl_t.item() + + mutual_info = calc_mi(model, test_loader, device=args.device) + + test_loss = (report_rec_loss + report_kl_loss) / report_num_examples + + nll = (report_kl_t_loss + report_rec_loss) / report_num_examples + kl = report_kl_loss / report_num_examples + kl_t = report_kl_t_loss / report_num_examples + + print('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f' % \ + (mode, test_loss, report_kl_t_loss / report_num_examples, 
mutual_info, + report_rec_loss / report_num_examples, nll)) + sys.stdout.flush() + + return test_loss, nll, kl_t ##返回真实的kl_t 不是训练中的kl + + +def calc_au(model, test_loader, delta=0.01): + """compute the number of active units + """ + means = [] + for datum in test_loader: + batch_data, _ = datum + + batch_data = batch_data.to(args.device) + + mean, _ = model.encode_stats(batch_data) + means.append(mean) + + means = torch.cat(means, dim=0) + au_mean = means.mean(0, keepdim=True) + + # (batch_size, nz) + au_var = means - au_mean + ns = au_var.size(0) + + au_var = (au_var ** 2).sum(dim=0) / (ns - 1) + + return (au_var >= delta).sum().item(), au_var + + +def calc_iwnll(model, test_loader, args): + report_nll_loss = 0 + report_num_examples = 0 + for id_, datum in enumerate(test_loader): + batch_data, _ = datum + batch_data = batch_data.to(args.device) + + batch_size = batch_data.size(0) + + report_num_examples += batch_size + + if id_ % (round(len(test_loader) / 10)) == 0: + print('iw nll computing %d0%%' % (id_ / (round(len(test_loader) / 10)))) + sys.stdout.flush() + + loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples) + + report_nll_loss += loss.sum().item() + + nll = report_nll_loss / report_num_examples + + print('iw nll: %.4f' % nll) + sys.stdout.flush() + return nll + + +def main(args): + if args.cuda: + print('using cuda') + print(args) + + args.device = torch.device(args.device) + device = args.device + + opt_dict = {"not_improved": 0, "lr": 0.001, "best_loss": 1e4} + + all_data = torch.load(args.data_file) + x_train, x_val, x_test = all_data + if args.dataset == 'omniglot': + + x_train = x_train.to(device) + x_val = x_val.to(device) + x_test = x_test.to(device) + y_size = 1 + y_train = x_train.new_zeros(x_train.size(0), y_size) + y_val = x_train.new_zeros(x_val.size(0), y_size) + y_test = x_train.new_zeros(x_test.size(0), y_size) + + print(torch.__version__) + train_data = torch.utils.data.TensorDataset(x_train, y_train) + val_data = 
torch.utils.data.TensorDataset(x_val, y_val) + test_data = torch.utils.data.TensorDataset(x_test, y_test) + + + train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True) + val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True) + print('Train data: %d batches' % len(train_loader)) + print('Val data: %d batches' % len(val_loader)) + print('Test data: %d batches' % len(test_loader)) + sys.stdout.flush() + + log_niter = len(train_loader) // 5 + + if args.gamma > 0: + encoder = BNResNetEncoderV2(args) + else: + encoder = ResNetEncoderV2(args) + + decoder = PixelCNNDecoderV2(args) + + vae = VAE(encoder, decoder, args).to(device) + + if args.eval: + print('begin evaluation') + test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True) + vae.load_state_dict(torch.load(args.load_path)) + vae.eval() + with torch.no_grad(): + test(vae, test_loader, "TEST", args) + au, au_var = calc_au(vae, test_loader) + print("%d active units" % au) + # print(au_var) + calc_iwnll(vae, test_loader, args) + return + + enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001) + dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001) + opt_dict['lr'] = 0.001 + + iter_ = 0 + best_loss = 1e4 + decay_cnt = 0 + vae.train() + start = time.time() + + kl_weight = args.kl_start + anneal_rate = (1.0 - args.kl_start) / (args.warm_up * len(train_loader)) + + for epoch in range(args.epochs): + + report_kl_loss = report_rec_loss = 0 + report_num_examples = 0 + for datum in train_loader: + batch_data, _ = datum + batch_data = batch_data.to(device) + if args.dataset != 'fashion-mnist': + batch_data = torch.bernoulli(batch_data) + batch_size = batch_data.size(0) + + report_num_examples += batch_size + + # kl_weight = 1.0 + + kl_weight = min(1.0, kl_weight + anneal_rate) + args.kl_weight = kl_weight + + 
enc_optimizer.zero_grad() + dec_optimizer.zero_grad() + + loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args) + + loss = loss.mean(dim=-1) + + loss.backward() + torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad) + + loss_rc = loss_rc.sum() + loss_kl = loss_kl.sum() + + enc_optimizer.step() + dec_optimizer.step() + + report_rec_loss += loss_rc.item() + report_kl_loss += loss_kl.item() + + if iter_ % log_niter == 0: + + train_loss = (report_rec_loss + report_kl_loss) / report_num_examples + if epoch == 0: + vae.eval() + with torch.no_grad(): + mi = calc_mi(vae, val_loader, device=device) + au, _ = calc_au(vae, val_loader) + + vae.train() + + print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \ + 'au %d, time elapsed %.2fs' % + (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi, + report_rec_loss / report_num_examples, au, time.time() - start)) + else: + print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \ + 'time elapsed %.2fs' % + (epoch, iter_, train_loss, report_kl_loss / report_num_examples, + report_rec_loss / report_num_examples, time.time() - start)) + sys.stdout.flush() + + report_rec_loss = report_kl_loss = 0 + report_num_examples = 0 + + iter_ += 1 + + print('kl weight %.4f' % args.kl_weight) + print('epoch: %d, VAL' % epoch) + + vae.eval() + + with torch.no_grad(): + loss, nll, kl = test(vae, val_loader, "VAL", args) + au, au_var = calc_au(vae, val_loader) + print("%d active units" % au) + # print(au_var) + + if loss < best_loss: + print('update best loss') + best_loss = loss + torch.save(vae.state_dict(), args.save_path) + + if loss > best_loss: + opt_dict["not_improved"] += 1 + if opt_dict["not_improved"] >= decay_epoch: + opt_dict["best_loss"] = loss + opt_dict["not_improved"] = 0 + opt_dict["lr"] = opt_dict["lr"] * lr_decay + vae.load_state_dict(torch.load(args.save_path)) + decay_cnt += 1 + print('new lr: %f' % opt_dict["lr"]) + enc_optimizer = 
optim.Adam(vae.encoder.parameters(), lr=opt_dict["lr"]) + dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=opt_dict["lr"]) + else: + opt_dict["not_improved"] = 0 + opt_dict["best_loss"] = loss + + if decay_cnt == max_decay: + break + + if epoch % args.test_nepoch == 0: + with torch.no_grad(): + loss, nll, kl = test(vae, test_loader, "TEST", args) + + vae.train() + + # compute importance weighted estimate of log p(x) + vae.load_state_dict(torch.load(args.save_path)) + vae.eval() + with torch.no_grad(): + loss, nll, kl = test(vae, test_loader, "TEST", args) + au, au_var = calc_au(vae, test_loader) + print("%d active units" % au) + # print(au_var) + + test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True) + + with torch.no_grad(): + calc_iwnll(vae, test_loader, args) + + +if __name__ == '__main__': + args = init_config() + if not args.eval: + sys.stdout = Logger(args.log_path) + main(args) diff --git a/image_IAF.py b/image_IAF.py new file mode 100644 index 0000000..ea5be8f --- /dev/null +++ b/image_IAF.py @@ -0,0 +1,419 @@ +import sys +import os +import time +import importlib +import argparse + +import numpy as np + +import torch +import torch.utils.data +# from torchvision.utils import save_image +from torch import nn, optim + +from modules import FlowResNetEncoderV2, PixelCNNDecoderV2 +from modules import VAEIAF as VAE +from logger import Logger +from utils import calc_mi + +clip_grad = 5.0 +decay_epoch = 20 +lr_decay = 0.5 +max_decay = 5 + + +def init_config(): + parser = argparse.ArgumentParser(description='VAE mode collapse study') + + # model hyperparameters + parser.add_argument('--dataset', default='omniglot', type=str, help='dataset to use') + + # optimization parameters + parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training') + parser.add_argument('--iw_nsamples', type=int, default=500, + help='number of samples to compute importance weighted estimate') + # select mode + 
parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') + parser.add_argument('--load_path', type=str, default='') + # annealing paramters + parser.add_argument('--warm_up', type=int, default=10) + parser.add_argument('--kl_start', type=float, default=1.0) + # these are for slurm purpose to save model + parser.add_argument('--jobid', type=int, default=0, help='slurm job id') + parser.add_argument('--taskid', type=int, default=0, help='slurm task id') + parser.add_argument('--device', type=str, default="cpu") + parser.add_argument('--delta_rate', type=float, default=1.0, + help=" coontrol the minization of the variation of latent variables") + parser.add_argument('--gamma', type=float, default=0.5) # BN-VAE + parser.add_argument("--reset_dec", action="store_true", default=False) + parser.add_argument("--nz_new", type=int, default=32) # myGaussianLSTMencoder + parser.add_argument('--p_drop', type=float, default=0.15) # p \in [0, 1] + + parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow") + parser.add_argument('--flow_width', type=int, default=2, help="width of flow") + + parser.add_argument("--fb", type=int, default=0, + help="0: no fb; 1: ") + + parser.add_argument("--target_kl", type=float, default=0.0, + help="target kl of the free bits trick") + parser.add_argument('--drop_start', type=float, default=1.0, help="starting KL weight") + + args = parser.parse_args() + if 'cuda' in args.device: + args.cuda = True + else: + args.cuda = False + + load_str = "_load" if args.load_path != "" else "" + save_dir = "models/%s%s/" % (args.dataset, load_str) + + if args.warm_up > 0 and args.kl_start < 1.0: + cw_str = '_warm%d' % args.warm_up + '_%.2f' % args.kl_start + else: + cw_str = '' + + if args.fb == 0: + fb_str = "" + elif args.fb in [1, 2]: + fb_str = "_fb%d_tr%.2f" % (args.fb, args.target_kl) + + else: + fb_str = '' + + drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else '' + if 1.0 > args.drop_start 
> 0 and args.p_drop != 0: + drop_str += '_start%.2f' % args.drop_start + + seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909] + args.seed = seed_set[args.taskid] + + if args.gamma > 0: + gamma_str = '_gamma%.2f' % (args.gamma) + + else: + gamma_str = '' + + if args.flow_depth > 0: + fd_str = '_fd%d_fw%d' % (args.flow_depth, args.flow_width) + + id_ = "%s%s%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d_IAF" % \ + (args.dataset, cw_str, load_str, gamma_str, fb_str, fd_str, + args.delta_rate, args.nz_new, drop_str, + args.jobid, args.taskid, args.seed) + save_dir += id_ + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + save_path = os.path.join(save_dir, 'model.pt') + + args.save_path = save_path + print("save path", args.save_path) + + args.log_path = os.path.join(save_dir, "log.txt") + print("log path", args.log_path) + + # load config file into args + config_file = "config.config_%s" % args.dataset + params = importlib.import_module(config_file).params + args = argparse.Namespace(**vars(args), **params) + if args.nz != args.nz_new: + args.nz = args.nz_new + print('args.nz', args.nz) + + if 'label' in params: + args.label = params['label'] + else: + args.label = False + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.cuda: + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + return args + + +def test(model, test_loader, mode, args): + report_kl_loss = report_kl_t_loss = report_rec_loss = 0 + report_num_examples = 0 + mutual_info = [] + for datum in test_loader: + batch_data, _ = datum + batch_data = batch_data.to(args.device) + + batch_size = batch_data.size(0) + + report_num_examples += batch_size + + loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False) + loss_kl_t = model.KL(batch_data, args) + + assert (not loss_rc.requires_grad) + + loss_rc = loss_rc.sum() + loss_kl = loss_kl.sum() + loss_kl_t = loss_kl_t.sum() + + report_rec_loss += loss_rc.item() + report_kl_loss += 
loss_kl.item() + report_kl_t_loss += loss_kl_t.item() + + mutual_info = calc_mi(model, test_loader, device=args.device) + + test_loss = (report_rec_loss + report_kl_loss) / report_num_examples + + nll = (report_kl_t_loss + report_rec_loss) / report_num_examples + kl = report_kl_loss / report_num_examples + kl_t = report_kl_t_loss / report_num_examples + + print('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f' % \ + (mode, test_loss, report_kl_t_loss / report_num_examples, mutual_info, + report_rec_loss / report_num_examples, nll)) + sys.stdout.flush() + + return test_loss, nll, kl_t ##返回真实的kl_t 不是训练中的kl + + +def calc_iwnll(model, test_loader, args): + report_nll_loss = 0 + report_num_examples = 0 + for id_, datum in enumerate(test_loader): + batch_data, _ = datum + batch_data = batch_data.to(args.device) + + batch_size = batch_data.size(0) + + report_num_examples += batch_size + + if id_ % (round(len(test_loader) / 10)) == 0: + print('iw nll computing %d0%%' % (id_ / (round(len(test_loader) / 10)))) + sys.stdout.flush() + + loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples) + + report_nll_loss += loss.sum().item() + + nll = report_nll_loss / report_num_examples + + print('iw nll: %.4f' % nll) + sys.stdout.flush() + return nll + +def calc_au(model, test_loader, delta=0.01): + """compute the number of active units + """ + means = [] + for datum in test_loader: + batch_data, _ = datum + + batch_data = batch_data.to(args.device) + + mean, _ = model.encode_stats(batch_data) + means.append(mean) + + means = torch.cat(means, dim=0) + au_mean = means.mean(0, keepdim=True) + + # (batch_size, nz) + au_var = means - au_mean + ns = au_var.size(0) + + au_var = (au_var ** 2).sum(dim=0) / (ns - 1) + + return (au_var >= delta).sum().item(), au_var + +def main(args): + if args.cuda: + print('using cuda') + print(args) + + args.device = torch.device(args.device) + device = args.device + + opt_dict = {"not_improved": 0, "lr": 0.001, "best_loss": 1e4} + + 
all_data = torch.load(args.data_file) + x_train, x_val, x_test = all_data + if args.dataset == 'omniglot': + x_train = x_train.to(device) + x_val = x_val.to(device) + x_test = x_test.to(device) + y_size = 1 + y_train = x_train.new_zeros(x_train.size(0), y_size) + y_val = x_train.new_zeros(x_val.size(0), y_size) + y_test = x_train.new_zeros(x_test.size(0), y_size) + + print(torch.__version__) + train_data = torch.utils.data.TensorDataset(x_train, y_train) + val_data = torch.utils.data.TensorDataset(x_val, y_val) + test_data = torch.utils.data.TensorDataset(x_test, y_test) + + train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True) + val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True) + print('Train data: %d batches' % len(train_loader)) + print('Val data: %d batches' % len(val_loader)) + print('Test data: %d batches' % len(test_loader)) + sys.stdout.flush() + + log_niter = len(train_loader) // 5 + + encoder = FlowResNetEncoderV2(args) + decoder = PixelCNNDecoderV2(args) + + vae = VAE(encoder, decoder, args).to(device) + + + if args.eval: + print('begin evaluation') + args.kl_weight = 1 + test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True) + vae.load_state_dict(torch.load(args.load_path)) + vae.eval() + with torch.no_grad(): + test(vae, test_loader, "TEST", args) + au, au_var = calc_au(vae, test_loader) + print("%d active units" % au) + # print(au_var) + calc_iwnll(vae, test_loader, args) + return + + enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001) + dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001) + opt_dict['lr'] = 0.001 + + iter_ = 0 + best_loss = 1e4 + best_kl = best_nll = best_ppl = 0 + decay_cnt = pre_mi = best_mi = mi_not_improved = 0 + vae.train() + start = time.time() + + kl_weight = args.kl_start + anneal_rate = (1.0 - 
args.kl_start) / (args.warm_up * len(train_loader)) + + for epoch in range(args.epochs): + + report_kl_loss = report_rec_loss = 0 + report_num_examples = 0 + for datum in train_loader: + batch_data, _ = datum + batch_data = batch_data.to(device) + batch_data = torch.bernoulli(batch_data) + batch_size = batch_data.size(0) + report_num_examples += batch_size + + kl_weight = min(1.0, kl_weight + anneal_rate) + args.kl_weight = kl_weight + + enc_optimizer.zero_grad() + dec_optimizer.zero_grad() + + loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args) + + loss = loss.mean(dim=-1) + + loss.backward() + torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad) + + loss_rc = loss_rc.sum() + loss_kl = loss_kl.sum() + + enc_optimizer.step() + dec_optimizer.step() + + report_rec_loss += loss_rc.item() + report_kl_loss += loss_kl.item() + + if iter_ % log_niter == 0: + + train_loss = (report_rec_loss + report_kl_loss) / report_num_examples + if epoch == 0: + vae.eval() + with torch.no_grad(): + mi = calc_mi(vae, val_loader, device=device) + au, _ = calc_au(vae, val_loader) + + vae.train() + + print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \ + 'au %d, time elapsed %.2fs' % + (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi, + report_rec_loss / report_num_examples, au, time.time() - start)) + else: + print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \ + 'time elapsed %.2fs' % + (epoch, iter_, train_loss, report_kl_loss / report_num_examples, + report_rec_loss / report_num_examples, time.time() - start)) + sys.stdout.flush() + + report_rec_loss = report_kl_loss = 0 + report_num_examples = 0 + + iter_ += 1 + + + + print('kl weight %.4f' % args.kl_weight) + print('epoch: %d, VAL' % epoch) + + vae.eval() + + with torch.no_grad(): + loss, nll, kl = test(vae, val_loader, "VAL", args) + au, au_var = calc_au(vae, val_loader) + print("%d active units" % au) + # print(au_var) + + if loss < best_loss: + 
print('update best loss') + best_loss = loss + best_nll = nll + best_kl = kl + torch.save(vae.state_dict(), args.save_path) + + if loss > best_loss: + opt_dict["not_improved"] += 1 + if opt_dict["not_improved"] >= decay_epoch: + opt_dict["best_loss"] = loss + opt_dict["not_improved"] = 0 + opt_dict["lr"] = opt_dict["lr"] * lr_decay + vae.load_state_dict(torch.load(args.save_path)) + decay_cnt += 1 + print('new lr: %f' % opt_dict["lr"]) + enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=opt_dict["lr"]) + dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=opt_dict["lr"]) + else: + opt_dict["not_improved"] = 0 + opt_dict["best_loss"] = loss + + if decay_cnt == max_decay: + break + + if epoch % args.test_nepoch == 0: + with torch.no_grad(): + loss, nll, kl = test(vae, test_loader, "TEST", args) + + vae.train() + + # compute importance weighted estimate of log p(x) + vae.load_state_dict(torch.load(args.save_path)) + vae.eval() + with torch.no_grad(): + loss, nll, kl = test(vae, test_loader, "TEST", args) + au, au_var = calc_au(vae, test_loader) + print("%d active units" % au) + # print(au_var) + + test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True) + + with torch.no_grad(): + calc_iwnll(vae, test_loader, args) + + +if __name__ == '__main__': + args = init_config() + if not args.eval: + sys.stdout = Logger(args.log_path) + main(args) diff --git a/image_ss_omniglot.py b/image_ss_omniglot.py new file mode 100644 index 0000000..509e02f --- /dev/null +++ b/image_ss_omniglot.py @@ -0,0 +1,311 @@ +import os +import time +import importlib +import argparse +import sys +import numpy as np + +import torch +from torch import nn, optim + +from modules import ResNetEncoderV2,BNResNetEncoderV2, PixelCNNDecoderV2,FlowResNetEncoderV2 +from modules import VAE, LinearDiscriminator_only +from logger import Logger +from omniglotDataset import Omniglot + + +# Junxian's new parameters +clip_grad = 1.0 +decay_epoch = 2 +lr_decay = 0.8 +max_decay = 5 + 
+ +def init_config(): + parser = argparse.ArgumentParser(description='VAE mode collapse study') + parser.add_argument('--gamma', type=float, default=0.0) + parser.add_argument('--gamma_type', type=str, default='BN') + parser.add_argument('--gamma_train', action="store_true", default=False) + + # model hyperparameters + parser.add_argument('--delta', type=float, default=0.0) + parser.add_argument('--dataset', default='omniglot', type=str, help='dataset to use') + # optimization parameters + parser.add_argument('--momentum', type=float, default=0.9, help='sgd momentum') + parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="adam", help='sgd momentum') + + parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training') + parser.add_argument('--iw_nsamples', type=int, default=500, + help='number of samples to compute importance weighted estimate') + + # select mode + parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') + parser.add_argument('--load_path', type=str, + default='models/mnist/test/model.pt') # TODO: 设定load_path + + # annealing paramters + parser.add_argument('--warm_up', type=int, default=100, + help="number of annealing epochs. warm_up=0 means not anneal") + parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight") + + # output directory + parser.add_argument('--exp_dir', default=None, type=str, + help='experiment directory.') + parser.add_argument("--save_ckpt", type=int, default=0, + help="save checkpoint every epoch before this number") + parser.add_argument("--save_latent", type=int, default=0) + + # new + parser.add_argument("--reset_dec", action="store_true", default=True) + parser.add_argument("--load_best_epoch", type=int, default=0) + parser.add_argument("--lr", type=float, default=1.) 
+ + parser.add_argument("--fb", type=int, default=0, + help="0: no fb; 1: fb; E") + parser.add_argument("--target_kl", type=float, default=-1, + help="target kl of the free bits trick") + + parser.add_argument("--batch_size", type=int, default=50, + help="number of epochs") + parser.add_argument("--epochs", type=int, default=300, + help="number of epochs") + parser.add_argument("--num_label", type=int, default=100, + help="t") + parser.add_argument("--freeze_enc", action="store_true", + default=True) # True-> freeze the parameters of vae.encoder + parser.add_argument("--discriminator", type=str, default="linear") + + parser.add_argument('--taskid', type=int, default=0, help='slurm task id') + parser.add_argument('--device', type=str, default="cpu") + parser.add_argument('--delta_rate', type=float, default=0.0, + help=" coontrol the minization of the variation of latent variables") + + parser.add_argument("--nz_new", type=int, default=32) # myGaussianLSTMencoder + + parser.add_argument('--IAF', action='store_true', default=False) + parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow") + parser.add_argument('--flow_width', type=int, default=60, help="width of flow") + parser.add_argument('--p_drop', type=float, default=0) # p \in [0, 1] + + args = parser.parse_args() + + # args.load_path ='models/omniglot/omniglot_aggressive1_KL1.00_dr0.00_beta-1.00_nz32_0_0_783435_betaF_4/model.pt' + # args.load_path ='models/omniglot/omniglot_aggressive0_KL0.00_warm10_gamma0.50_BN_train5_dr1.00_beta-1.00_nz32_drop0.15_0_0_783435_betaF_5_large_de20/model.pt' + # args.gamma = 0.5 + # args.gamma_type = 'BN' + # args.load_path = 'models/omniglot/omniglot_fb2_tr0.20_fd2_fw60_dr0.00_nz32_0_0_783435_IAF/model.pt' + # args.IAF = True + + # if len(args.load_path)>0: + # args.load_path = 'models/'+args.dataset+'/'+args.load_path + + # set args.cuda + if 'cuda' in args.device: + args.cuda = True + else: + args.cuda = False + + # set seeds + seed_set = [783435, 101, 
202, 303, 404, 505, 606, 707, 808, 909] + args.seed = seed_set[args.taskid] + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + if args.cuda: + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + + config_file = "config.config_%s_ss" % args.dataset + params = importlib.import_module(config_file).params + args = argparse.Namespace(**vars(args), **params) + + load_str = "_load" if args.load_path != "" else "" + opt_str = "_adam" if args.opt == "adam" else "_sgd" + nlabel_str = "_nlabel{}".format(args.num_label) + freeze_str = "_freeze" if args.freeze_enc else "" + + if len(args.load_path.split("/")) > 2: + load_path_str = args.load_path.split("/")[2] + else: + load_path_str = args.load_path.split("/")[1] + + model_str = "_{}".format(args.discriminator) + # set load and save paths + if args.exp_dir == None: + args.exp_dir = "models/exp_{}{}_ss_ft/{}{}{}{}{}".format(args.dataset, + load_str, load_path_str, model_str, opt_str, + nlabel_str, freeze_str) + if not os.path.exists(args.exp_dir): + os.makedirs(args.exp_dir) + args.log_path = os.path.join(args.exp_dir, 'log.txt') + + # set args.label + if 'label' in params: + args.label = params['label'] + else: + args.label = False + + args.kl_weight = 1 + + return args + + +def test(model, test_loader, mode, args, verbose=False): + + report_correct = report_loss = 0 + report_num_sents = 0 + + N=0 + for datum in test_loader: + batch_data, batch_labels = datum + batch_data = batch_data.to(args.device) + batch_labels = batch_labels.to(args.device).squeeze() + #batch_data = torch.bernoulli(batch_data) + batch_size = batch_data.size(0) + + # not predict start symbol + report_num_sents += batch_size + loss, correct = model.get_performance_with_feature(batch_data, batch_labels) + + loss = loss.sum() + + report_loss += loss.item() + report_correct += correct + N+=1 + + test_loss = report_loss / report_num_sents + acc = report_correct / report_num_sents + + if verbose: + print('%s --- avg_loss: 
%.4f, acc: %.4f' % \ + (mode, test_loss, acc)) + # sys.stdout.flush() + + return test_loss, acc + +def train(args,dataset:Omniglot,task,device,trainum=10): + x_train,l_train,x_test,l_test,NC = dataset.load_task(task,trainum) + train_data = torch.utils.data.TensorDataset(x_train, l_train) + test_data = torch.utils.data.TensorDataset(x_test, l_test) + train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False) + print('Train data: %d samples' % len(x_train)) + log_niter = max(1, (len(train_data) // (args.batch_size)) // 10) + + + if args.discriminator == "linear": + discriminator = LinearDiscriminator_only(args, NC).to(device) + + + if args.opt == "sgd": + optimizer = optim.SGD(discriminator.parameters(), lr=args.lr, momentum=args.momentum) + elif args.opt == "adam": + optimizer = optim.Adam(discriminator.parameters(), lr=0.001) + # optimizer = swats.SWATS(discriminator.parameters(), lr=0.001) + + else: + raise ValueError("optimizer not supported") + + discriminator.train() + + iter_ = 0 + acc_loss = 0. 
+ for epoch in range(args.epochs): + report_loss = 0 + report_correct = report_num_sents = 0 + acc_batch_size = 0 + optimizer.zero_grad() + for datum in train_loader: + batch_data, batch_labels = datum + batch_data = batch_data.to(device) + batch_labels = batch_labels.to(device).squeeze() + batch_size = batch_data.size(0) + if batch_data.size(0) < 2: + continue + + # not predict start symbol + report_num_sents += batch_size + acc_batch_size += batch_size + + # (batch_size) + loss, correct = discriminator.get_performance_with_feature(batch_data, batch_labels) + + acc_loss = loss.sum() + acc_loss.backward() + optimizer.step() + + optimizer.zero_grad() + + report_loss += loss.sum().item() + report_correct += correct + iter_ += 1 + + discriminator.eval() + with torch.no_grad(): + loss, acc = test(discriminator, test_loader, "VAL", args, verbose=False) + discriminator.train() + + # discriminator.load_state_dict(torch.load(args.save_path)) + discriminator.eval() + with torch.no_grad(): + loss, acc = test(discriminator, test_loader, "TEST", args,verbose=True) + return loss, acc,NC + +def main(args): + if args.cuda: + print('using cuda') + print(str(args)) + + device = args.device + + + if args.gamma > 0 and not args.IAF: + encoder = BNResNetEncoderV2(args) + elif not args.IAF: + encoder = ResNetEncoderV2(args) + elif args.IAF: + encoder = FlowResNetEncoderV2(args) + + decoder = PixelCNNDecoderV2(args,mode='large') # if args.HKL == 'H': + + vae = VAE(encoder, decoder, args).to(device) + + if args.load_path: + loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device)) + vae.load_state_dict(loaded_state_dict) + print("%s loaded" % args.load_path) + vae.eval() + + dataset = Omniglot(args.root, encoder=vae.encoder, device=device, IAF=args.IAF) + + for tasknum in [5,10,15]: + acc_sum=0 + acc_wsum=0 + N=0 + acclist=[] + NClist=[] + for task in range(50): + N+=1 + loss, acc,NC= train(args,dataset,task,device,tasknum) + acc_sum+=acc + acc_wsum+=NC*acc + 
acclist.append(acc) + NClist.append(NC) + acc_mean = acc_sum/N + acc_wmean = acc_wsum/sum(NClist) + + print('train_num', tasknum,'acc',acc_mean, acc_wmean) + print(acclist) + # plt.plot(range(50),acclist,label = 'trainnum%d'%tasknum) + # plt.show() + # plt.legend() + + + + + +if __name__ == '__main__': + args = init_config() + sys.stdout = Logger(args.log_path) + print('---------------') + main(args) diff --git a/logger.py b/logger.py new file mode 100755 index 0000000..04cd066 --- /dev/null +++ b/logger.py @@ -0,0 +1,17 @@ +""" +Logger class files +""" +import sys + +class Logger(object): + def __init__(self, output_file): + self.terminal = sys.stdout + self.log = open(output_file, "w") + + def write(self, message): + print(message, end="", file=self.terminal, flush=True) + print(message, end="", file=self.log, flush=True) + + def flush(self): + self.terminal.flush() + self.log.flush() diff --git a/models/.DS_Store b/models/.DS_Store new file mode 100644 index 0000000..8d70105 Binary files /dev/null and b/models/.DS_Store differ diff --git a/models/omniglot/.DS_Store b/models/omniglot/.DS_Store new file mode 100644 index 0000000..47f6cbb Binary files /dev/null and b/models/omniglot/.DS_Store differ diff --git a/models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt b/models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt new file mode 100644 index 0000000..fee3ace --- /dev/null +++ b/models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt @@ -0,0 +1,6 @@ +Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=0.0, device='cpu', enc_layers=[64, 64, 64], epochs=500, eval=False, gamma=0.0, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, kl_weight=1, label=False, latent_feature_map=4, load_path='', 
log_path='models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/model.pt', seed=783435, taskid=0, test_nepoch=5, warm_up=10) +1.5.0 +Train data: 447 batches +Val data: 40 batches +Test data: 162 batches +epoch: 0, iter: 0, avg_loss: 537.0994, kl: 30.8722, mi: 0.3572, recon: 506.2273,au 32, time elapsed 13.60s diff --git a/models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt b/models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt new file mode 100644 index 0000000..382017f --- /dev/null +++ b/models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt @@ -0,0 +1,6 @@ +Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=1.0, device='cpu', enc_layers=[64, 64, 64], epochs=500, eval=False, gamma=0.5, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, kl_weight=1, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0.2, reset_dec=False, save_path='models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/model.pt', seed=783435, taskid=0, test_nepoch=5, warm_up=10) +1.5.0 +Train data: 447 batches +Val data: 40 batches +Test data: 162 batches +epoch: 0, iter: 0, avg_loss: 528.2644, kl: 21.9613, mi: 0.1018, recon: 506.3031,au 0, time elapsed 12.70s diff --git a/models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt b/models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt new file mode 100644 index 0000000..6760dac --- /dev/null +++ 
b/models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt @@ -0,0 +1,12 @@ +Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=0.0, device='cpu', drop_start=1.0, enc_layers=[64, 64, 64], epochs=500, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.0, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/model.pt', seed=783435, target_kl=0.15, taskid=0, test_nepoch=5, warm_up=10) +1.5.0 +Train data: 447 batches +Val data: 40 batches +Test data: 162 batches +> /Users/shendazhong/Desktop/AAAI21/code_reference/Du-VAE/modules/encoders/enc_flow.py(64)encode() + 63  +---> 64  return z_T, kl.sum(dim=[1, 2]) # like KL + 65  + +ipdb> --KeyboardInterrupt-- +ipdb> \ No newline at end of file diff --git a/models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt b/models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt new file mode 100644 index 0000000..bcb0054 --- /dev/null +++ b/models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt @@ -0,0 +1,6 @@ +Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=0.0, device='cpu', drop_start=1.0, enc_layers=[64, 64, 64], epochs=500, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.0, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, label=False, latent_feature_map=4, load_path='', 
log_path='models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_nepoch=5, warm_up=10) +1.5.0 +Train data: 447 batches +Val data: 40 batches +Test data: 162 batches +epoch: 0, iter: 0, avg_loss: 571.8783, kl: 24.7818, mi: 1.0189, recon: 547.0965,au 32, time elapsed 12.62s diff --git a/models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt b/models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt new file mode 100644 index 0000000..87581dd --- /dev/null +++ b/models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt @@ -0,0 +1,6 @@ +Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=1.0, device='cpu', drop_start=1.0, enc_layers=[64, 64, 64], epochs=500, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.5, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0.15, reset_dec=False, save_path='models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_nepoch=5, warm_up=10) +1.5.0 +Train data: 447 batches +Val data: 40 batches +Test data: 162 batches +epoch: 0, iter: 0, avg_loss: 564.9081, kl: 17.1125, mi: 0.4995, recon: 547.7956,au 0, time elapsed 11.56s diff --git a/models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt 
b/models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt new file mode 100644 index 0000000..eca0b9d --- /dev/null +++ b/models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt @@ -0,0 +1,6 @@ +Namespace(batch_size=32, cuda=False, dataset='short_yelp', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=512, dec_type='lstm', delta_rate=1, device='cpu', enc_nh=512, enc_type='lstm', epochs=100, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=0.0, label=True, load_path='', log_niter=50, log_path='models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=128, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/short_yelp_data/short_yelp.test.txt', test_nepoch=5, train_data='data/short_yelp_data/short_yelp.train.txt', val_data='data/short_yelp_data/short_yelp.valid.txt', vocab_file='data/short_yelp_data/vocab.txt', warm_up=10) +data/short_yelp_data/vocab.txt +Train data: 100000 samples +finish reading datasets, vocab size is 8411 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 151.0429, kl/H(z|x): 15.4808, mi: 0.0923, recon: 135.5621,au 0, time elapsed 40.76s diff --git a/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt new file mode 100644 index 0000000..68c778f --- /dev/null +++ b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt @@ -0,0 +1,743 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=0.0, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=0.0, 
iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 76.0298, kl/H(z|x): 0.0000, mi: 0.0056, recon: 76.0298,au 0, time elapsed 3.91s +epoch: 0, iter: 50, avg_loss: 66.4376, kl/H(z|x): 0.0000, mi: 0.0607, recon: 66.4376,au 0, time elapsed 8.59s +epoch: 0, iter: 100, avg_loss: 60.9033, kl/H(z|x): 0.0000, mi: -0.0787, recon: 60.9032,au 0, time elapsed 13.32s +epoch: 0, iter: 150, avg_loss: 59.8395, kl/H(z|x): 0.0018, mi: 0.0027, recon: 59.8377,au 0, time elapsed 18.00s +epoch: 0, iter: 200, avg_loss: 57.6219, kl/H(z|x): 0.0059, mi: 0.0679, recon: 57.6160,au 0, time elapsed 22.66s +epoch: 0, iter: 250, avg_loss: 54.6628, kl/H(z|x): 0.0098, mi: 0.1069, recon: 54.6529,au 0, time elapsed 27.49s +epoch: 0, iter: 300, avg_loss: 53.0137, kl/H(z|x): 0.0110, mi: 0.0419, recon: 53.0027,au 0, time elapsed 32.34s +epoch: 0, iter: 350, avg_loss: 51.6142, kl/H(z|x): 0.0112, mi: 0.0623, recon: 51.6030,au 0, time elapsed 37.05s +epoch: 0, iter: 400, avg_loss: 51.3852, kl/H(z|x): 0.0135, mi: -0.0299, recon: 51.3717,au 0, time elapsed 41.78s +epoch: 0, iter: 450, avg_loss: 50.7162, kl/H(z|x): 0.0169, mi: 0.0734, recon: 50.6993,au 0, time elapsed 46.49s +kl weight 1.0000 +VAL --- avg_loss: 46.5658, kl/H(z|x): 0.0082, mi: -0.0036, recon: 46.5576, nll: 46.5658, ppl: 68.9411 +0 active units +update best loss +TEST --- 
avg_loss: 46.5894, kl/H(z|x): 0.0082, mi: -0.0030, recon: 46.5813, nll: 46.5894, ppl: 69.0896 +epoch: 1, iter: 500, avg_loss: 46.1078, kl/H(z|x): 0.0081, recon: 46.0997,time elapsed 56.18s +epoch: 1, iter: 550, avg_loss: 48.3830, kl/H(z|x): 0.0233, recon: 48.3596,time elapsed 57.28s +epoch: 1, iter: 600, avg_loss: 48.6702, kl/H(z|x): 0.0146, recon: 48.6555,time elapsed 58.38s +epoch: 1, iter: 650, avg_loss: 47.8828, kl/H(z|x): 0.0140, recon: 47.8688,time elapsed 59.44s +epoch: 1, iter: 700, avg_loss: 47.5835, kl/H(z|x): 0.0077, recon: 47.5758,time elapsed 60.52s +epoch: 1, iter: 750, avg_loss: 47.6507, kl/H(z|x): 0.0190, recon: 47.6317,time elapsed 61.58s +epoch: 1, iter: 800, avg_loss: 47.6542, kl/H(z|x): 0.0161, recon: 47.6381,time elapsed 62.70s +epoch: 1, iter: 850, avg_loss: 47.3931, kl/H(z|x): 0.0086, recon: 47.3845,time elapsed 63.78s +epoch: 1, iter: 900, avg_loss: 47.7290, kl/H(z|x): 0.0088, recon: 47.7202,time elapsed 64.83s +epoch: 1, iter: 950, avg_loss: 47.1321, kl/H(z|x): 0.0130, recon: 47.1190,time elapsed 65.92s +kl weight 1.0000 +VAL --- avg_loss: 45.2807, kl/H(z|x): 0.0073, mi: -0.0388, recon: 45.2735, nll: 45.2807, ppl: 61.3398 +0 active units +update best loss +TEST --- avg_loss: 45.2893, kl/H(z|x): 0.0073, mi: 0.0676, recon: 45.2821, nll: 45.2893, ppl: 61.3879 +epoch: 2, iter: 1000, avg_loss: 45.6955, kl/H(z|x): 0.0072, recon: 45.6883,time elapsed 75.45s +epoch: 2, iter: 1050, avg_loss: 46.7630, kl/H(z|x): 0.0097, recon: 46.7533,time elapsed 76.54s +epoch: 2, iter: 1100, avg_loss: 47.1458, kl/H(z|x): 0.0081, recon: 47.1377,time elapsed 77.67s +epoch: 2, iter: 1150, avg_loss: 47.0823, kl/H(z|x): 0.0073, recon: 47.0751,time elapsed 78.73s +epoch: 2, iter: 1200, avg_loss: 46.7632, kl/H(z|x): 0.0055, recon: 46.7577,time elapsed 79.80s +epoch: 2, iter: 1250, avg_loss: 47.3431, kl/H(z|x): 0.0089, recon: 47.3342,time elapsed 80.87s +epoch: 2, iter: 1300, avg_loss: 46.7255, kl/H(z|x): 0.0058, recon: 46.7197,time elapsed 81.92s +epoch: 2, iter: 1350, 
avg_loss: 46.7585, kl/H(z|x): 0.0034, recon: 46.7551,time elapsed 83.03s +epoch: 2, iter: 1400, avg_loss: 45.8850, kl/H(z|x): 0.0032, recon: 45.8818,time elapsed 84.11s +epoch: 2, iter: 1450, avg_loss: 46.3266, kl/H(z|x): 0.0028, recon: 46.3237,time elapsed 85.29s +kl weight 1.0000 +VAL --- avg_loss: 44.1531, kl/H(z|x): 0.0026, mi: -0.1020, recon: 44.1505, nll: 44.1531, ppl: 55.3633 +0 active units +update best loss +TEST --- avg_loss: 44.1502, kl/H(z|x): 0.0026, mi: 0.0268, recon: 44.1476, nll: 44.1502, ppl: 55.3488 +epoch: 3, iter: 1500, avg_loss: 45.4690, kl/H(z|x): 0.0026, recon: 45.4664,time elapsed 94.90s +epoch: 3, iter: 1550, avg_loss: 46.3678, kl/H(z|x): 0.0089, recon: 46.3588,time elapsed 95.99s +epoch: 3, iter: 1600, avg_loss: 46.3185, kl/H(z|x): 0.0036, recon: 46.3148,time elapsed 97.05s +epoch: 3, iter: 1650, avg_loss: 46.1619, kl/H(z|x): 0.0050, recon: 46.1569,time elapsed 98.13s +epoch: 3, iter: 1700, avg_loss: 46.0672, kl/H(z|x): 0.0035, recon: 46.0636,time elapsed 99.21s +epoch: 3, iter: 1750, avg_loss: 46.3371, kl/H(z|x): 0.0024, recon: 46.3347,time elapsed 100.29s +epoch: 3, iter: 1800, avg_loss: 46.4199, kl/H(z|x): 0.0049, recon: 46.4150,time elapsed 101.36s +epoch: 3, iter: 1850, avg_loss: 45.9181, kl/H(z|x): 0.0044, recon: 45.9137,time elapsed 102.49s +epoch: 3, iter: 1900, avg_loss: 46.0571, kl/H(z|x): 0.0087, recon: 46.0484,time elapsed 103.54s +epoch: 3, iter: 1950, avg_loss: 45.8462, kl/H(z|x): 0.0054, recon: 45.8408,time elapsed 104.61s +kl weight 1.0000 +VAL --- avg_loss: 44.0262, kl/H(z|x): 0.0022, mi: -0.0017, recon: 44.0241, nll: 44.0262, ppl: 54.7285 +0 active units +update best loss +TEST --- avg_loss: 44.0316, kl/H(z|x): 0.0022, mi: -0.1088, recon: 44.0294, nll: 44.0316, ppl: 54.7551 +epoch: 4, iter: 2000, avg_loss: 43.8706, kl/H(z|x): 0.0022, recon: 43.8684,time elapsed 114.37s +epoch: 4, iter: 2050, avg_loss: 45.8839, kl/H(z|x): 0.0051, recon: 45.8788,time elapsed 115.43s +epoch: 4, iter: 2100, avg_loss: 46.1152, kl/H(z|x): 
0.0088, recon: 46.1065,time elapsed 116.49s +epoch: 4, iter: 2150, avg_loss: 46.1732, kl/H(z|x): 0.0055, recon: 46.1677,time elapsed 117.59s +epoch: 4, iter: 2200, avg_loss: 46.1230, kl/H(z|x): 0.0056, recon: 46.1174,time elapsed 118.66s +epoch: 4, iter: 2250, avg_loss: 45.9930, kl/H(z|x): 0.0097, recon: 45.9833,time elapsed 119.75s +epoch: 4, iter: 2300, avg_loss: 46.5047, kl/H(z|x): 0.0126, recon: 46.4921,time elapsed 120.81s +epoch: 4, iter: 2350, avg_loss: 46.0809, kl/H(z|x): 0.0108, recon: 46.0701,time elapsed 121.87s +epoch: 4, iter: 2400, avg_loss: 46.5641, kl/H(z|x): 0.0058, recon: 46.5583,time elapsed 122.97s +epoch: 4, iter: 2450, avg_loss: 46.0603, kl/H(z|x): 0.0050, recon: 46.0553,time elapsed 124.03s +kl weight 1.0000 +VAL --- avg_loss: 44.3252, kl/H(z|x): 0.0032, mi: 0.0661, recon: 44.3220, nll: 44.3252, ppl: 56.2365 +0 active units +TEST --- avg_loss: 44.3159, kl/H(z|x): 0.0032, mi: 0.0157, recon: 44.3127, nll: 44.3159, ppl: 56.1890 +epoch: 5, iter: 2500, avg_loss: 46.6240, kl/H(z|x): 0.0032, recon: 46.6208,time elapsed 133.62s +epoch: 5, iter: 2550, avg_loss: 45.8931, kl/H(z|x): 0.0056, recon: 45.8875,time elapsed 134.71s +epoch: 5, iter: 2600, avg_loss: 45.6455, kl/H(z|x): 0.0056, recon: 45.6399,time elapsed 135.79s +epoch: 5, iter: 2650, avg_loss: 45.9743, kl/H(z|x): 0.0043, recon: 45.9701,time elapsed 136.84s +epoch: 5, iter: 2700, avg_loss: 45.7565, kl/H(z|x): 0.0045, recon: 45.7520,time elapsed 137.95s +epoch: 5, iter: 2750, avg_loss: 45.9835, kl/H(z|x): 0.0050, recon: 45.9785,time elapsed 139.02s +epoch: 5, iter: 2800, avg_loss: 45.7868, kl/H(z|x): 0.0041, recon: 45.7827,time elapsed 140.11s +epoch: 5, iter: 2850, avg_loss: 45.7049, kl/H(z|x): 0.0042, recon: 45.7007,time elapsed 141.18s +epoch: 5, iter: 2900, avg_loss: 46.0195, kl/H(z|x): 0.0034, recon: 46.0161,time elapsed 142.27s +epoch: 5, iter: 2950, avg_loss: 46.1072, kl/H(z|x): 0.0051, recon: 46.1022,time elapsed 143.38s +kl weight 1.0000 +VAL --- avg_loss: 43.8109, kl/H(z|x): 0.0038, 
mi: -0.0220, recon: 43.8071, nll: 43.8109, ppl: 53.6676 +0 active units +update best loss +TEST --- avg_loss: 43.8015, kl/H(z|x): 0.0038, mi: -0.0332, recon: 43.7977, nll: 43.8015, ppl: 53.6218 +epoch: 6, iter: 3000, avg_loss: 43.1108, kl/H(z|x): 0.0038, recon: 43.1070,time elapsed 153.22s +epoch: 6, iter: 3050, avg_loss: 45.8055, kl/H(z|x): 0.0040, recon: 45.8015,time elapsed 154.31s +epoch: 6, iter: 3100, avg_loss: 45.9045, kl/H(z|x): 0.0038, recon: 45.9007,time elapsed 155.38s +epoch: 6, iter: 3150, avg_loss: 45.1809, kl/H(z|x): 0.0025, recon: 45.1784,time elapsed 156.45s +epoch: 6, iter: 3200, avg_loss: 45.6232, kl/H(z|x): 0.0026, recon: 45.6206,time elapsed 157.51s +epoch: 6, iter: 3250, avg_loss: 45.1418, kl/H(z|x): 0.0020, recon: 45.1398,time elapsed 158.63s +epoch: 6, iter: 3300, avg_loss: 45.6125, kl/H(z|x): 0.0031, recon: 45.6094,time elapsed 159.70s +epoch: 6, iter: 3350, avg_loss: 45.3761, kl/H(z|x): 0.0019, recon: 45.3742,time elapsed 160.77s +epoch: 6, iter: 3400, avg_loss: 45.3928, kl/H(z|x): 0.0022, recon: 45.3906,time elapsed 161.88s +epoch: 6, iter: 3450, avg_loss: 45.8120, kl/H(z|x): 0.0024, recon: 45.8096,time elapsed 162.97s +kl weight 1.0000 +VAL --- avg_loss: 43.5564, kl/H(z|x): 0.0021, mi: 0.0195, recon: 43.5543, nll: 43.5564, ppl: 52.4400 +0 active units +update best loss +TEST --- avg_loss: 43.5776, kl/H(z|x): 0.0021, mi: 0.0936, recon: 43.5755, nll: 43.5776, ppl: 52.5412 +epoch: 7, iter: 3500, avg_loss: 44.6366, kl/H(z|x): 0.0021, recon: 44.6345,time elapsed 172.75s +epoch: 7, iter: 3550, avg_loss: 45.0213, kl/H(z|x): 0.0021, recon: 45.0192,time elapsed 173.83s +epoch: 7, iter: 3600, avg_loss: 45.6376, kl/H(z|x): 0.0029, recon: 45.6347,time elapsed 174.87s +epoch: 7, iter: 3650, avg_loss: 45.7895, kl/H(z|x): 0.0036, recon: 45.7860,time elapsed 175.95s +epoch: 7, iter: 3700, avg_loss: 45.5016, kl/H(z|x): 0.0038, recon: 45.4978,time elapsed 177.00s +epoch: 7, iter: 3750, avg_loss: 45.6823, kl/H(z|x): 0.0032, recon: 45.6791,time elapsed 
178.11s +epoch: 7, iter: 3800, avg_loss: 45.2862, kl/H(z|x): 0.0034, recon: 45.2827,time elapsed 179.19s +epoch: 7, iter: 3850, avg_loss: 45.3230, kl/H(z|x): 0.0029, recon: 45.3201,time elapsed 180.27s +epoch: 7, iter: 3900, avg_loss: 45.1383, kl/H(z|x): 0.0025, recon: 45.1358,time elapsed 181.35s +epoch: 7, iter: 3950, avg_loss: 45.4385, kl/H(z|x): 0.0027, recon: 45.4358,time elapsed 182.47s +kl weight 1.0000 +VAL --- avg_loss: 43.6658, kl/H(z|x): 0.0022, mi: 0.0445, recon: 43.6635, nll: 43.6658, ppl: 52.9641 +0 active units +TEST --- avg_loss: 43.6665, kl/H(z|x): 0.0022, mi: 0.0049, recon: 43.6643, nll: 43.6665, ppl: 52.9676 +epoch: 8, iter: 4000, avg_loss: 46.5223, kl/H(z|x): 0.0022, recon: 46.5200,time elapsed 191.89s +epoch: 8, iter: 4050, avg_loss: 44.8934, kl/H(z|x): 0.0035, recon: 44.8899,time elapsed 193.07s +epoch: 8, iter: 4100, avg_loss: 45.4792, kl/H(z|x): 0.0068, recon: 45.4724,time elapsed 194.14s +epoch: 8, iter: 4150, avg_loss: 45.5920, kl/H(z|x): 0.0026, recon: 45.5894,time elapsed 195.25s +epoch: 8, iter: 4200, avg_loss: 45.3591, kl/H(z|x): 0.0034, recon: 45.3557,time elapsed 196.38s +epoch: 8, iter: 4250, avg_loss: 45.7200, kl/H(z|x): 0.0038, recon: 45.7162,time elapsed 197.49s +epoch: 8, iter: 4300, avg_loss: 45.4247, kl/H(z|x): 0.0022, recon: 45.4225,time elapsed 198.56s +epoch: 8, iter: 4350, avg_loss: 45.4409, kl/H(z|x): 0.0042, recon: 45.4367,time elapsed 199.64s +epoch: 8, iter: 4400, avg_loss: 45.1211, kl/H(z|x): 0.0025, recon: 45.1186,time elapsed 200.72s +epoch: 8, iter: 4450, avg_loss: 45.2950, kl/H(z|x): 0.0024, recon: 45.2926,time elapsed 201.78s +kl weight 1.0000 +VAL --- avg_loss: 43.5856, kl/H(z|x): 0.0023, mi: 0.1106, recon: 43.5833, nll: 43.5856, ppl: 52.5797 +0 active units +TEST --- avg_loss: 43.6012, kl/H(z|x): 0.0023, mi: -0.0433, recon: 43.5989, nll: 43.6012, ppl: 52.6544 +epoch: 9, iter: 4500, avg_loss: 44.3652, kl/H(z|x): 0.0023, recon: 44.3629,time elapsed 211.64s +epoch: 9, iter: 4550, avg_loss: 45.4353, kl/H(z|x): 
0.0063, recon: 45.4290,time elapsed 212.78s +epoch: 9, iter: 4600, avg_loss: 45.5352, kl/H(z|x): 0.0076, recon: 45.5276,time elapsed 213.88s +epoch: 9, iter: 4650, avg_loss: 44.9239, kl/H(z|x): 0.0022, recon: 44.9218,time elapsed 214.96s +epoch: 9, iter: 4700, avg_loss: 44.9674, kl/H(z|x): 0.0025, recon: 44.9649,time elapsed 216.06s +epoch: 9, iter: 4750, avg_loss: 45.3179, kl/H(z|x): 0.0024, recon: 45.3154,time elapsed 217.15s +epoch: 9, iter: 4800, avg_loss: 45.2692, kl/H(z|x): 0.0022, recon: 45.2670,time elapsed 218.25s +epoch: 9, iter: 4850, avg_loss: 45.2401, kl/H(z|x): 0.0017, recon: 45.2384,time elapsed 219.32s +epoch: 9, iter: 4900, avg_loss: 45.3045, kl/H(z|x): 0.0019, recon: 45.3026,time elapsed 220.39s +epoch: 9, iter: 4950, avg_loss: 45.3250, kl/H(z|x): 0.0023, recon: 45.3228,time elapsed 221.44s +kl weight 1.0000 +VAL --- avg_loss: 43.6818, kl/H(z|x): 0.0018, mi: -0.1053, recon: 43.6799, nll: 43.6818, ppl: 53.0412 +0 active units +TEST --- avg_loss: 43.6995, kl/H(z|x): 0.0018, mi: 0.1338, recon: 43.6977, nll: 43.6995, ppl: 53.1268 +epoch: 10, iter: 5000, avg_loss: 44.2786, kl/H(z|x): 0.0019, recon: 44.2768,time elapsed 230.94s +epoch: 10, iter: 5050, avg_loss: 45.4548, kl/H(z|x): 0.0018, recon: 45.4530,time elapsed 232.11s +epoch: 10, iter: 5100, avg_loss: 45.1848, kl/H(z|x): 0.0019, recon: 45.1830,time elapsed 233.17s +epoch: 10, iter: 5150, avg_loss: 44.9300, kl/H(z|x): 0.0021, recon: 44.9279,time elapsed 234.25s +epoch: 10, iter: 5200, avg_loss: 45.4789, kl/H(z|x): 0.0024, recon: 45.4765,time elapsed 235.36s +epoch: 10, iter: 5250, avg_loss: 45.0654, kl/H(z|x): 0.0024, recon: 45.0630,time elapsed 236.44s +epoch: 10, iter: 5300, avg_loss: 45.7442, kl/H(z|x): 0.0025, recon: 45.7417,time elapsed 237.52s +epoch: 10, iter: 5350, avg_loss: 45.4598, kl/H(z|x): 0.0032, recon: 45.4566,time elapsed 238.61s +epoch: 10, iter: 5400, avg_loss: 45.2135, kl/H(z|x): 0.0019, recon: 45.2116,time elapsed 239.67s +epoch: 10, iter: 5450, avg_loss: 44.7176, kl/H(z|x): 
0.0018, recon: 44.7159,time elapsed 240.73s +kl weight 1.0000 +VAL --- avg_loss: 43.5995, kl/H(z|x): 0.0014, mi: 0.0867, recon: 43.5981, nll: 43.5995, ppl: 52.6458 +0 active units +TEST --- avg_loss: 43.5949, kl/H(z|x): 0.0014, mi: -0.0502, recon: 43.5935, nll: 43.5949, ppl: 52.6239 +epoch: 11, iter: 5500, avg_loss: 43.6854, kl/H(z|x): 0.0014, recon: 43.6840,time elapsed 250.17s +epoch: 11, iter: 5550, avg_loss: 44.8833, kl/H(z|x): 0.0018, recon: 44.8815,time elapsed 251.27s +epoch: 11, iter: 5600, avg_loss: 45.3650, kl/H(z|x): 0.0018, recon: 45.3632,time elapsed 252.37s +epoch: 11, iter: 5650, avg_loss: 45.1862, kl/H(z|x): 0.0023, recon: 45.1839,time elapsed 253.52s +epoch: 11, iter: 5700, avg_loss: 45.1249, kl/H(z|x): 0.0020, recon: 45.1229,time elapsed 254.60s +epoch: 11, iter: 5750, avg_loss: 45.1719, kl/H(z|x): 0.0020, recon: 45.1699,time elapsed 255.67s +epoch: 11, iter: 5800, avg_loss: 44.8795, kl/H(z|x): 0.0018, recon: 44.8778,time elapsed 256.74s +epoch: 11, iter: 5850, avg_loss: 44.8756, kl/H(z|x): 0.0020, recon: 44.8735,time elapsed 257.85s +epoch: 11, iter: 5900, avg_loss: 45.5389, kl/H(z|x): 0.0023, recon: 45.5366,time elapsed 258.92s +epoch: 11, iter: 5950, avg_loss: 44.9864, kl/H(z|x): 0.0022, recon: 44.9843,time elapsed 260.01s +kl weight 1.0000 +VAL --- avg_loss: 43.4362, kl/H(z|x): 0.0013, mi: -0.0682, recon: 43.4349, nll: 43.4362, ppl: 51.8702 +0 active units +update best loss +TEST --- avg_loss: 43.4197, kl/H(z|x): 0.0013, mi: 0.0301, recon: 43.4184, nll: 43.4197, ppl: 51.7926 +epoch: 12, iter: 6000, avg_loss: 42.4735, kl/H(z|x): 0.0013, recon: 42.4721,time elapsed 269.71s +epoch: 12, iter: 6050, avg_loss: 44.9651, kl/H(z|x): 0.0019, recon: 44.9631,time elapsed 270.82s +epoch: 12, iter: 6100, avg_loss: 45.2931, kl/H(z|x): 0.0020, recon: 45.2912,time elapsed 271.90s +epoch: 12, iter: 6150, avg_loss: 45.0727, kl/H(z|x): 0.0021, recon: 45.0707,time elapsed 272.95s +epoch: 12, iter: 6200, avg_loss: 44.5651, kl/H(z|x): 0.0017, recon: 44.5635,time 
elapsed 274.02s +epoch: 12, iter: 6250, avg_loss: 44.6325, kl/H(z|x): 0.0017, recon: 44.6308,time elapsed 275.08s +epoch: 12, iter: 6300, avg_loss: 45.3624, kl/H(z|x): 0.0018, recon: 45.3605,time elapsed 276.14s +epoch: 12, iter: 6350, avg_loss: 45.1114, kl/H(z|x): 0.0024, recon: 45.1090,time elapsed 277.23s +epoch: 12, iter: 6400, avg_loss: 45.0876, kl/H(z|x): 0.0025, recon: 45.0852,time elapsed 278.34s +epoch: 12, iter: 6450, avg_loss: 45.3110, kl/H(z|x): 0.0023, recon: 45.3087,time elapsed 279.39s +kl weight 1.0000 +VAL --- avg_loss: 43.3657, kl/H(z|x): 0.0018, mi: 0.1527, recon: 43.3640, nll: 43.3657, ppl: 51.5389 +0 active units +update best loss +TEST --- avg_loss: 43.3642, kl/H(z|x): 0.0018, mi: -0.1021, recon: 43.3624, nll: 43.3642, ppl: 51.5317 +epoch: 13, iter: 6500, avg_loss: 47.0057, kl/H(z|x): 0.0017, recon: 47.0040,time elapsed 288.82s +epoch: 13, iter: 6550, avg_loss: 44.6513, kl/H(z|x): 0.0027, recon: 44.6486,time elapsed 289.90s +epoch: 13, iter: 6600, avg_loss: 45.1560, kl/H(z|x): 0.0015, recon: 45.1545,time elapsed 291.00s +epoch: 13, iter: 6650, avg_loss: 45.0566, kl/H(z|x): 0.0019, recon: 45.0547,time elapsed 292.29s +epoch: 13, iter: 6700, avg_loss: 45.3155, kl/H(z|x): 0.0051, recon: 45.3104,time elapsed 293.44s +epoch: 13, iter: 6750, avg_loss: 45.5072, kl/H(z|x): 0.0033, recon: 45.5039,time elapsed 294.55s +epoch: 13, iter: 6800, avg_loss: 45.3792, kl/H(z|x): 0.0046, recon: 45.3747,time elapsed 295.67s +epoch: 13, iter: 6850, avg_loss: 45.1679, kl/H(z|x): 0.0026, recon: 45.1653,time elapsed 296.80s +epoch: 13, iter: 6900, avg_loss: 45.7968, kl/H(z|x): 0.0042, recon: 45.7926,time elapsed 297.90s +epoch: 13, iter: 6950, avg_loss: 45.2510, kl/H(z|x): 0.0045, recon: 45.2464,time elapsed 299.03s +kl weight 1.0000 +VAL --- avg_loss: 43.3716, kl/H(z|x): 0.0023, mi: -0.0546, recon: 43.3693, nll: 43.3716, ppl: 51.5668 +0 active units +TEST --- avg_loss: 43.3622, kl/H(z|x): 0.0023, mi: -0.0436, recon: 43.3599, nll: 43.3622, ppl: 51.5224 +epoch: 14, 
iter: 7000, avg_loss: 41.5074, kl/H(z|x): 0.0023, recon: 41.5051,time elapsed 308.56s +epoch: 14, iter: 7050, avg_loss: 45.4000, kl/H(z|x): 0.0030, recon: 45.3970,time elapsed 309.65s +epoch: 14, iter: 7100, avg_loss: 44.7859, kl/H(z|x): 0.0028, recon: 44.7830,time elapsed 310.71s +epoch: 14, iter: 7150, avg_loss: 45.1264, kl/H(z|x): 0.0026, recon: 45.1238,time elapsed 311.79s +epoch: 14, iter: 7200, avg_loss: 45.0590, kl/H(z|x): 0.0023, recon: 45.0566,time elapsed 312.87s +epoch: 14, iter: 7250, avg_loss: 45.1674, kl/H(z|x): 0.0028, recon: 45.1646,time elapsed 314.04s +epoch: 14, iter: 7300, avg_loss: 45.4094, kl/H(z|x): 0.0038, recon: 45.4056,time elapsed 315.11s +epoch: 14, iter: 7350, avg_loss: 44.7291, kl/H(z|x): 0.0041, recon: 44.7250,time elapsed 316.20s +epoch: 14, iter: 7400, avg_loss: 45.2026, kl/H(z|x): 0.0023, recon: 45.2003,time elapsed 317.33s +epoch: 14, iter: 7450, avg_loss: 45.1990, kl/H(z|x): 0.0022, recon: 45.1968,time elapsed 318.41s +kl weight 1.0000 +VAL --- avg_loss: 43.3621, kl/H(z|x): 0.0028, mi: -0.0980, recon: 43.3593, nll: 43.3621, ppl: 51.5220 +0 active units +update best loss +TEST --- avg_loss: 43.3668, kl/H(z|x): 0.0028, mi: 0.0047, recon: 43.3640, nll: 43.3668, ppl: 51.5442 +epoch: 15, iter: 7500, avg_loss: 45.3504, kl/H(z|x): 0.0028, recon: 45.3475,time elapsed 328.06s +epoch: 15, iter: 7550, avg_loss: 44.6513, kl/H(z|x): 0.0034, recon: 44.6480,time elapsed 329.16s +epoch: 15, iter: 7600, avg_loss: 44.5779, kl/H(z|x): 0.0029, recon: 44.5750,time elapsed 330.27s +epoch: 15, iter: 7650, avg_loss: 45.3756, kl/H(z|x): 0.0024, recon: 45.3732,time elapsed 331.43s +epoch: 15, iter: 7700, avg_loss: 45.0796, kl/H(z|x): 0.0040, recon: 45.0756,time elapsed 332.54s +epoch: 15, iter: 7750, avg_loss: 44.8463, kl/H(z|x): 0.0028, recon: 44.8435,time elapsed 333.66s +epoch: 15, iter: 7800, avg_loss: 45.4175, kl/H(z|x): 0.0040, recon: 45.4135,time elapsed 334.77s +epoch: 15, iter: 7850, avg_loss: 45.3280, kl/H(z|x): 0.0024, recon: 45.3256,time 
elapsed 335.88s +epoch: 15, iter: 7900, avg_loss: 45.0953, kl/H(z|x): 0.0026, recon: 45.0926,time elapsed 336.97s +epoch: 15, iter: 7950, avg_loss: 45.5351, kl/H(z|x): 0.0019, recon: 45.5332,time elapsed 338.10s +kl weight 1.0000 +VAL --- avg_loss: 43.7256, kl/H(z|x): 0.0015, mi: -0.0466, recon: 43.7241, nll: 43.7256, ppl: 53.2529 +0 active units +TEST --- avg_loss: 43.7168, kl/H(z|x): 0.0015, mi: 0.0051, recon: 43.7153, nll: 43.7168, ppl: 53.2103 +epoch: 16, iter: 8000, avg_loss: 48.4416, kl/H(z|x): 0.0015, recon: 48.4401,time elapsed 347.57s +epoch: 16, iter: 8050, avg_loss: 44.7687, kl/H(z|x): 0.0016, recon: 44.7672,time elapsed 348.68s +epoch: 16, iter: 8100, avg_loss: 44.4672, kl/H(z|x): 0.0016, recon: 44.4656,time elapsed 349.77s +epoch: 16, iter: 8150, avg_loss: 45.1387, kl/H(z|x): 0.0014, recon: 45.1373,time elapsed 350.86s +epoch: 16, iter: 8200, avg_loss: 45.0043, kl/H(z|x): 0.0031, recon: 45.0012,time elapsed 352.23s +epoch: 16, iter: 8250, avg_loss: 45.2002, kl/H(z|x): 0.0018, recon: 45.1984,time elapsed 353.31s +epoch: 16, iter: 8300, avg_loss: 45.3384, kl/H(z|x): 0.0020, recon: 45.3364,time elapsed 354.36s +epoch: 16, iter: 8350, avg_loss: 45.0697, kl/H(z|x): 0.0023, recon: 45.0674,time elapsed 355.48s +epoch: 16, iter: 8400, avg_loss: 45.1004, kl/H(z|x): 0.0015, recon: 45.0988,time elapsed 356.56s +epoch: 16, iter: 8450, avg_loss: 44.7239, kl/H(z|x): 0.0019, recon: 44.7219,time elapsed 357.69s +kl weight 1.0000 +VAL --- avg_loss: 43.3528, kl/H(z|x): 0.0016, mi: -0.0783, recon: 43.3512, nll: 43.3528, ppl: 51.4786 +0 active units +update best loss +TEST --- avg_loss: 43.3674, kl/H(z|x): 0.0016, mi: -0.0893, recon: 43.3658, nll: 43.3674, ppl: 51.5467 +epoch: 17, iter: 8500, avg_loss: 44.1956, kl/H(z|x): 0.0016, recon: 44.1940,time elapsed 367.18s +epoch: 17, iter: 8550, avg_loss: 44.4295, kl/H(z|x): 0.0022, recon: 44.4273,time elapsed 368.29s +epoch: 17, iter: 8600, avg_loss: 45.0237, kl/H(z|x): 0.0025, recon: 45.0211,time elapsed 369.39s +epoch: 17, 
iter: 8650, avg_loss: 45.2083, kl/H(z|x): 0.0016, recon: 45.2067,time elapsed 370.48s +epoch: 17, iter: 8700, avg_loss: 45.0486, kl/H(z|x): 0.0021, recon: 45.0465,time elapsed 371.57s +epoch: 17, iter: 8750, avg_loss: 45.1202, kl/H(z|x): 0.0022, recon: 45.1180,time elapsed 372.68s +epoch: 17, iter: 8800, avg_loss: 45.5614, kl/H(z|x): 0.0018, recon: 45.5596,time elapsed 373.76s +epoch: 17, iter: 8850, avg_loss: 44.9460, kl/H(z|x): 0.0018, recon: 44.9442,time elapsed 374.91s +epoch: 17, iter: 8900, avg_loss: 45.2071, kl/H(z|x): 0.0018, recon: 45.2053,time elapsed 376.01s +epoch: 17, iter: 8950, avg_loss: 45.1195, kl/H(z|x): 0.0020, recon: 45.1175,time elapsed 377.09s +kl weight 1.0000 +VAL --- avg_loss: 43.5796, kl/H(z|x): 0.0035, mi: 0.1645, recon: 43.5761, nll: 43.5796, ppl: 52.5508 +0 active units +TEST --- avg_loss: 43.5770, kl/H(z|x): 0.0035, mi: -0.0107, recon: 43.5734, nll: 43.5770, ppl: 52.5384 +epoch: 18, iter: 9000, avg_loss: 44.2011, kl/H(z|x): 0.0035, recon: 44.1976,time elapsed 386.56s +epoch: 18, iter: 9050, avg_loss: 45.1314, kl/H(z|x): 0.0018, recon: 45.1296,time elapsed 387.79s +epoch: 18, iter: 9100, avg_loss: 44.8351, kl/H(z|x): 0.0019, recon: 44.8331,time elapsed 388.95s +epoch: 18, iter: 9150, avg_loss: 44.7113, kl/H(z|x): 0.0019, recon: 44.7095,time elapsed 390.06s +epoch: 18, iter: 9200, avg_loss: 44.9135, kl/H(z|x): 0.0016, recon: 44.9119,time elapsed 391.16s +epoch: 18, iter: 9250, avg_loss: 44.8027, kl/H(z|x): 0.0023, recon: 44.8005,time elapsed 392.26s +epoch: 18, iter: 9300, avg_loss: 45.2509, kl/H(z|x): 0.0025, recon: 45.2485,time elapsed 393.38s +epoch: 18, iter: 9350, avg_loss: 45.4845, kl/H(z|x): 0.0023, recon: 45.4823,time elapsed 394.47s +epoch: 18, iter: 9400, avg_loss: 45.1169, kl/H(z|x): 0.0036, recon: 45.1133,time elapsed 395.54s +epoch: 18, iter: 9450, avg_loss: 45.0380, kl/H(z|x): 0.0023, recon: 45.0357,time elapsed 396.64s +kl weight 1.0000 +VAL --- avg_loss: 43.2472, kl/H(z|x): 0.0016, mi: -0.0371, recon: 43.2456, nll: 
43.2472, ppl: 50.9865 +0 active units +update best loss +TEST --- avg_loss: 43.2442, kl/H(z|x): 0.0016, mi: -0.0985, recon: 43.2427, nll: 43.2442, ppl: 50.9729 +epoch: 19, iter: 9500, avg_loss: 43.3491, kl/H(z|x): 0.0016, recon: 43.3475,time elapsed 406.21s +epoch: 19, iter: 9550, avg_loss: 44.9036, kl/H(z|x): 0.0019, recon: 44.9017,time elapsed 407.30s +epoch: 19, iter: 9600, avg_loss: 45.0755, kl/H(z|x): 0.0016, recon: 45.0739,time elapsed 408.42s +epoch: 19, iter: 9650, avg_loss: 44.7566, kl/H(z|x): 0.0016, recon: 44.7551,time elapsed 409.54s +epoch: 19, iter: 9700, avg_loss: 44.9752, kl/H(z|x): 0.0015, recon: 44.9737,time elapsed 410.67s +epoch: 19, iter: 9750, avg_loss: 44.8479, kl/H(z|x): 0.0018, recon: 44.8461,time elapsed 411.93s +epoch: 19, iter: 9800, avg_loss: 45.3882, kl/H(z|x): 0.0016, recon: 45.3866,time elapsed 413.19s +epoch: 19, iter: 9850, avg_loss: 44.9943, kl/H(z|x): 0.0016, recon: 44.9927,time elapsed 414.33s +epoch: 19, iter: 9900, avg_loss: 44.6496, kl/H(z|x): 0.0025, recon: 44.6471,time elapsed 415.42s +epoch: 19, iter: 9950, avg_loss: 44.8607, kl/H(z|x): 0.0018, recon: 44.8588,time elapsed 416.51s +kl weight 1.0000 +VAL --- avg_loss: 43.2890, kl/H(z|x): 0.0025, mi: -0.0208, recon: 43.2864, nll: 43.2890, ppl: 51.1807 +0 active units +TEST --- avg_loss: 43.2645, kl/H(z|x): 0.0025, mi: -0.0091, recon: 43.2619, nll: 43.2645, ppl: 51.0668 +epoch: 20, iter: 10000, avg_loss: 46.8134, kl/H(z|x): 0.0025, recon: 46.8108,time elapsed 426.14s +epoch: 20, iter: 10050, avg_loss: 44.5414, kl/H(z|x): 0.0015, recon: 44.5399,time elapsed 427.27s +epoch: 20, iter: 10100, avg_loss: 44.8945, kl/H(z|x): 0.0022, recon: 44.8923,time elapsed 428.39s +epoch: 20, iter: 10150, avg_loss: 44.5182, kl/H(z|x): 0.0016, recon: 44.5166,time elapsed 429.50s +epoch: 20, iter: 10200, avg_loss: 44.9697, kl/H(z|x): 0.0016, recon: 44.9681,time elapsed 430.60s +epoch: 20, iter: 10250, avg_loss: 45.2996, kl/H(z|x): 0.0024, recon: 45.2973,time elapsed 431.72s +epoch: 20, iter: 10300, 
avg_loss: 45.0469, kl/H(z|x): 0.0023, recon: 45.0446,time elapsed 432.89s +epoch: 20, iter: 10350, avg_loss: 44.9282, kl/H(z|x): 0.0021, recon: 44.9261,time elapsed 434.01s +epoch: 20, iter: 10400, avg_loss: 45.3862, kl/H(z|x): 0.0018, recon: 45.3844,time elapsed 435.10s +epoch: 20, iter: 10450, avg_loss: 44.9310, kl/H(z|x): 0.0017, recon: 44.9292,time elapsed 436.25s +kl weight 1.0000 +VAL --- avg_loss: 43.3190, kl/H(z|x): 0.0026, mi: -0.0782, recon: 43.3164, nll: 43.3190, ppl: 51.3205 +0 active units +TEST --- avg_loss: 43.3047, kl/H(z|x): 0.0026, mi: 0.0224, recon: 43.3021, nll: 43.3047, ppl: 51.2536 +epoch: 21, iter: 10500, avg_loss: 43.5460, kl/H(z|x): 0.0026, recon: 43.5434,time elapsed 445.80s +epoch: 21, iter: 10550, avg_loss: 44.5981, kl/H(z|x): 0.0017, recon: 44.5964,time elapsed 446.89s +epoch: 21, iter: 10600, avg_loss: 45.0290, kl/H(z|x): 0.0018, recon: 45.0272,time elapsed 448.20s +epoch: 21, iter: 10650, avg_loss: 44.9590, kl/H(z|x): 0.0015, recon: 44.9575,time elapsed 449.30s +epoch: 21, iter: 10700, avg_loss: 44.7474, kl/H(z|x): 0.0013, recon: 44.7460,time elapsed 450.43s +epoch: 21, iter: 10750, avg_loss: 44.8494, kl/H(z|x): 0.0014, recon: 44.8480,time elapsed 451.53s +epoch: 21, iter: 10800, avg_loss: 44.7393, kl/H(z|x): 0.0016, recon: 44.7377,time elapsed 452.66s +epoch: 21, iter: 10850, avg_loss: 44.8197, kl/H(z|x): 0.0019, recon: 44.8178,time elapsed 453.75s +epoch: 21, iter: 10900, avg_loss: 45.2948, kl/H(z|x): 0.0025, recon: 45.2923,time elapsed 454.83s +epoch: 21, iter: 10950, avg_loss: 45.0870, kl/H(z|x): 0.0012, recon: 45.0859,time elapsed 455.92s +kl weight 1.0000 +VAL --- avg_loss: 43.4467, kl/H(z|x): 0.0016, mi: -0.0297, recon: 43.4450, nll: 43.4467, ppl: 51.9196 +0 active units +TEST --- avg_loss: 43.4356, kl/H(z|x): 0.0016, mi: 0.0437, recon: 43.4339, nll: 43.4356, ppl: 51.8673 +epoch: 22, iter: 11000, avg_loss: 46.4844, kl/H(z|x): 0.0016, recon: 46.4828,time elapsed 465.42s +epoch: 22, iter: 11050, avg_loss: 44.5962, kl/H(z|x): 
0.0016, recon: 44.5946,time elapsed 466.52s +epoch: 22, iter: 11100, avg_loss: 44.8375, kl/H(z|x): 0.0019, recon: 44.8355,time elapsed 467.66s +epoch: 22, iter: 11150, avg_loss: 44.6757, kl/H(z|x): 0.0015, recon: 44.6742,time elapsed 468.76s +epoch: 22, iter: 11200, avg_loss: 45.0933, kl/H(z|x): 0.0015, recon: 45.0918,time elapsed 469.87s +epoch: 22, iter: 11250, avg_loss: 45.0801, kl/H(z|x): 0.0017, recon: 45.0784,time elapsed 470.97s +epoch: 22, iter: 11300, avg_loss: 44.6852, kl/H(z|x): 0.0017, recon: 44.6835,time elapsed 472.09s +epoch: 22, iter: 11350, avg_loss: 45.4787, kl/H(z|x): 0.0016, recon: 45.4771,time elapsed 473.22s +epoch: 22, iter: 11400, avg_loss: 44.6543, kl/H(z|x): 0.0016, recon: 44.6527,time elapsed 474.32s +epoch: 22, iter: 11450, avg_loss: 44.8205, kl/H(z|x): 0.0013, recon: 44.8191,time elapsed 475.42s +kl weight 1.0000 +VAL --- avg_loss: 43.5634, kl/H(z|x): 0.0009, mi: -0.0111, recon: 43.5625, nll: 43.5634, ppl: 52.4734 +0 active units +TEST --- avg_loss: 43.5467, kl/H(z|x): 0.0009, mi: -0.0235, recon: 43.5458, nll: 43.5467, ppl: 52.3937 +epoch: 23, iter: 11500, avg_loss: 48.4146, kl/H(z|x): 0.0009, recon: 48.4137,time elapsed 484.94s +epoch: 23, iter: 11550, avg_loss: 45.0193, kl/H(z|x): 0.0015, recon: 45.0178,time elapsed 486.04s +epoch: 23, iter: 11600, avg_loss: 44.6590, kl/H(z|x): 0.0015, recon: 44.6575,time elapsed 487.14s +epoch: 23, iter: 11650, avg_loss: 44.8284, kl/H(z|x): 0.0015, recon: 44.8269,time elapsed 488.26s +epoch: 23, iter: 11700, avg_loss: 44.7549, kl/H(z|x): 0.0013, recon: 44.7536,time elapsed 489.36s +epoch: 23, iter: 11750, avg_loss: 44.8092, kl/H(z|x): 0.0013, recon: 44.8079,time elapsed 490.46s +epoch: 23, iter: 11800, avg_loss: 44.7132, kl/H(z|x): 0.0016, recon: 44.7116,time elapsed 491.59s +epoch: 23, iter: 11850, avg_loss: 45.3561, kl/H(z|x): 0.0012, recon: 45.3548,time elapsed 492.71s +epoch: 23, iter: 11900, avg_loss: 44.7064, kl/H(z|x): 0.0013, recon: 44.7051,time elapsed 493.79s +epoch: 23, iter: 11950, 
avg_loss: 44.9692, kl/H(z|x): 0.0013, recon: 44.9679,time elapsed 494.90s +kl weight 1.0000 +VAL --- avg_loss: 43.3256, kl/H(z|x): 0.0012, mi: -0.0326, recon: 43.3244, nll: 43.3256, ppl: 51.3512 +0 active units +new lr: 0.500000 +TEST --- avg_loss: 43.2480, kl/H(z|x): 0.0016, mi: -0.0007, recon: 43.2465, nll: 43.2480, ppl: 50.9905 +epoch: 24, iter: 12000, avg_loss: 46.3091, kl/H(z|x): 0.0016, recon: 46.3076,time elapsed 504.33s +epoch: 24, iter: 12050, avg_loss: 44.2064, kl/H(z|x): 0.0017, recon: 44.2047,time elapsed 505.42s +epoch: 24, iter: 12100, avg_loss: 44.2650, kl/H(z|x): 0.0013, recon: 44.2637,time elapsed 506.51s +epoch: 24, iter: 12150, avg_loss: 44.8498, kl/H(z|x): 0.0013, recon: 44.8485,time elapsed 507.73s +epoch: 24, iter: 12200, avg_loss: 44.0413, kl/H(z|x): 0.0010, recon: 44.0403,time elapsed 508.88s +epoch: 24, iter: 12250, avg_loss: 44.0497, kl/H(z|x): 0.0009, recon: 44.0488,time elapsed 509.98s +epoch: 24, iter: 12300, avg_loss: 43.7201, kl/H(z|x): 0.0009, recon: 43.7193,time elapsed 511.08s +epoch: 24, iter: 12350, avg_loss: 44.5270, kl/H(z|x): 0.0012, recon: 44.5259,time elapsed 512.16s +epoch: 24, iter: 12400, avg_loss: 44.4890, kl/H(z|x): 0.0010, recon: 44.4880,time elapsed 513.28s +epoch: 24, iter: 12450, avg_loss: 44.3600, kl/H(z|x): 0.0009, recon: 44.3591,time elapsed 514.36s +kl weight 1.0000 +VAL --- avg_loss: 42.8909, kl/H(z|x): 0.0007, mi: -0.0319, recon: 42.8902, nll: 42.8909, ppl: 49.3614 +0 active units +update best loss +TEST --- avg_loss: 42.8864, kl/H(z|x): 0.0007, mi: 0.0619, recon: 42.8857, nll: 42.8864, ppl: 49.3415 +epoch: 25, iter: 12500, avg_loss: 43.0568, kl/H(z|x): 0.0007, recon: 43.0561,time elapsed 523.70s +epoch: 25, iter: 12550, avg_loss: 43.7582, kl/H(z|x): 0.0008, recon: 43.7573,time elapsed 524.80s +epoch: 25, iter: 12600, avg_loss: 43.9676, kl/H(z|x): 0.0005, recon: 43.9671,time elapsed 525.90s +epoch: 25, iter: 12650, avg_loss: 44.1848, kl/H(z|x): 0.0005, recon: 44.1843,time elapsed 527.00s +epoch: 25, iter: 
12700, avg_loss: 44.3981, kl/H(z|x): 0.0005, recon: 44.3976,time elapsed 528.14s +epoch: 25, iter: 12750, avg_loss: 44.3295, kl/H(z|x): 0.0005, recon: 44.3290,time elapsed 529.22s +epoch: 25, iter: 12800, avg_loss: 44.1709, kl/H(z|x): 0.0004, recon: 44.1705,time elapsed 530.32s +epoch: 25, iter: 12850, avg_loss: 43.7556, kl/H(z|x): 0.0005, recon: 43.7550,time elapsed 531.78s +epoch: 25, iter: 12900, avg_loss: 44.2555, kl/H(z|x): 0.0004, recon: 44.2550,time elapsed 533.00s +epoch: 25, iter: 12950, avg_loss: 44.1798, kl/H(z|x): 0.0004, recon: 44.1793,time elapsed 534.12s +kl weight 1.0000 +VAL --- avg_loss: 42.9512, kl/H(z|x): 0.0005, mi: -0.0766, recon: 42.9507, nll: 42.9512, ppl: 49.6328 +0 active units +TEST --- avg_loss: 42.9562, kl/H(z|x): 0.0005, mi: -0.0393, recon: 42.9557, nll: 42.9562, ppl: 49.6553 +epoch: 26, iter: 13000, avg_loss: 44.2116, kl/H(z|x): 0.0005, recon: 44.2112,time elapsed 543.43s +epoch: 26, iter: 13050, avg_loss: 43.9048, kl/H(z|x): 0.0004, recon: 43.9044,time elapsed 544.51s +epoch: 26, iter: 13100, avg_loss: 43.7069, kl/H(z|x): 0.0004, recon: 43.7065,time elapsed 545.61s +epoch: 26, iter: 13150, avg_loss: 43.7116, kl/H(z|x): 0.0004, recon: 43.7113,time elapsed 546.69s +epoch: 26, iter: 13200, avg_loss: 43.3178, kl/H(z|x): 0.0004, recon: 43.3174,time elapsed 547.82s +epoch: 26, iter: 13250, avg_loss: 44.7637, kl/H(z|x): 0.0005, recon: 44.7632,time elapsed 548.92s +epoch: 26, iter: 13300, avg_loss: 44.1573, kl/H(z|x): 0.0004, recon: 44.1569,time elapsed 550.03s +epoch: 26, iter: 13350, avg_loss: 44.3795, kl/H(z|x): 0.0005, recon: 44.3790,time elapsed 551.11s +epoch: 26, iter: 13400, avg_loss: 44.2767, kl/H(z|x): 0.0005, recon: 44.2761,time elapsed 552.20s +epoch: 26, iter: 13450, avg_loss: 43.9566, kl/H(z|x): 0.0004, recon: 43.9561,time elapsed 553.31s +kl weight 1.0000 +VAL --- avg_loss: 42.8416, kl/H(z|x): 0.0003, mi: 0.0639, recon: 42.8413, nll: 42.8416, ppl: 49.1409 +0 active units +update best loss +TEST --- avg_loss: 42.8492, 
kl/H(z|x): 0.0003, mi: 0.0448, recon: 42.8488, nll: 42.8492, ppl: 49.1747 +epoch: 27, iter: 13500, avg_loss: 44.2243, kl/H(z|x): 0.0003, recon: 44.2240,time elapsed 562.83s +epoch: 27, iter: 13550, avg_loss: 43.7787, kl/H(z|x): 0.0005, recon: 43.7782,time elapsed 563.92s +epoch: 27, iter: 13600, avg_loss: 43.8797, kl/H(z|x): 0.0004, recon: 43.8793,time elapsed 565.00s +epoch: 27, iter: 13650, avg_loss: 44.2306, kl/H(z|x): 0.0006, recon: 44.2301,time elapsed 566.18s +epoch: 27, iter: 13700, avg_loss: 43.7812, kl/H(z|x): 0.0005, recon: 43.7807,time elapsed 567.29s +epoch: 27, iter: 13750, avg_loss: 43.6520, kl/H(z|x): 0.0003, recon: 43.6517,time elapsed 568.59s +epoch: 27, iter: 13800, avg_loss: 44.3062, kl/H(z|x): 0.0004, recon: 44.3058,time elapsed 569.68s +epoch: 27, iter: 13850, avg_loss: 43.7553, kl/H(z|x): 0.0004, recon: 43.7549,time elapsed 570.80s +epoch: 27, iter: 13900, avg_loss: 44.5534, kl/H(z|x): 0.0004, recon: 44.5529,time elapsed 571.90s +epoch: 27, iter: 13950, avg_loss: 44.0106, kl/H(z|x): 0.0004, recon: 44.0102,time elapsed 573.03s +kl weight 1.0000 +VAL --- avg_loss: 42.8331, kl/H(z|x): 0.0004, mi: 0.1504, recon: 42.8327, nll: 42.8331, ppl: 49.1029 +0 active units +update best loss +TEST --- avg_loss: 42.8320, kl/H(z|x): 0.0004, mi: 0.0753, recon: 42.8316, nll: 42.8320, ppl: 49.0980 +epoch: 28, iter: 14000, avg_loss: 39.6822, kl/H(z|x): 0.0005, recon: 39.6818,time elapsed 582.32s +epoch: 28, iter: 14050, avg_loss: 43.7772, kl/H(z|x): 0.0004, recon: 43.7768,time elapsed 583.51s +epoch: 28, iter: 14100, avg_loss: 44.1639, kl/H(z|x): 0.0004, recon: 44.1635,time elapsed 584.65s +epoch: 28, iter: 14150, avg_loss: 43.8640, kl/H(z|x): 0.0004, recon: 43.8636,time elapsed 585.79s +epoch: 28, iter: 14200, avg_loss: 43.9483, kl/H(z|x): 0.0005, recon: 43.9479,time elapsed 586.92s +epoch: 28, iter: 14250, avg_loss: 43.9844, kl/H(z|x): 0.0006, recon: 43.9838,time elapsed 588.11s +epoch: 28, iter: 14300, avg_loss: 44.2249, kl/H(z|x): 0.0005, recon: 44.2243,time 
elapsed 589.23s +epoch: 28, iter: 14350, avg_loss: 43.7008, kl/H(z|x): 0.0004, recon: 43.7004,time elapsed 590.35s +epoch: 28, iter: 14400, avg_loss: 43.8595, kl/H(z|x): 0.0004, recon: 43.8591,time elapsed 591.52s +epoch: 28, iter: 14450, avg_loss: 44.3218, kl/H(z|x): 0.0004, recon: 44.3214,time elapsed 592.68s +kl weight 1.0000 +VAL --- avg_loss: 42.8649, kl/H(z|x): 0.0004, mi: -0.0395, recon: 42.8646, nll: 42.8649, ppl: 49.2452 +0 active units +TEST --- avg_loss: 42.8816, kl/H(z|x): 0.0004, mi: 0.1842, recon: 42.8813, nll: 42.8816, ppl: 49.3200 +epoch: 29, iter: 14500, avg_loss: 43.6188, kl/H(z|x): 0.0003, recon: 43.6184,time elapsed 602.00s +epoch: 29, iter: 14550, avg_loss: 43.7076, kl/H(z|x): 0.0004, recon: 43.7072,time elapsed 603.14s +epoch: 29, iter: 14600, avg_loss: 43.7143, kl/H(z|x): 0.0004, recon: 43.7139,time elapsed 604.23s +epoch: 29, iter: 14650, avg_loss: 44.3535, kl/H(z|x): 0.0004, recon: 44.3531,time elapsed 605.30s +epoch: 29, iter: 14700, avg_loss: 43.7975, kl/H(z|x): 0.0004, recon: 43.7971,time elapsed 606.38s +epoch: 29, iter: 14750, avg_loss: 44.0038, kl/H(z|x): 0.0004, recon: 44.0034,time elapsed 607.48s +epoch: 29, iter: 14800, avg_loss: 44.1614, kl/H(z|x): 0.0004, recon: 44.1611,time elapsed 608.57s +epoch: 29, iter: 14850, avg_loss: 43.5950, kl/H(z|x): 0.0003, recon: 43.5947,time elapsed 609.66s +epoch: 29, iter: 14900, avg_loss: 44.0414, kl/H(z|x): 0.0003, recon: 44.0411,time elapsed 610.74s +epoch: 29, iter: 14950, avg_loss: 43.7423, kl/H(z|x): 0.0005, recon: 43.7418,time elapsed 611.83s +kl weight 1.0000 +VAL --- avg_loss: 42.7344, kl/H(z|x): 0.0005, mi: 0.1063, recon: 42.7340, nll: 42.7344, ppl: 48.6645 +0 active units +update best loss +TEST --- avg_loss: 42.7331, kl/H(z|x): 0.0005, mi: -0.0573, recon: 42.7326, nll: 42.7331, ppl: 48.6584 +epoch: 30, iter: 15000, avg_loss: 43.9892, kl/H(z|x): 0.0005, recon: 43.9888,time elapsed 621.32s +epoch: 30, iter: 15050, avg_loss: 43.6377, kl/H(z|x): 0.0005, recon: 43.6372,time elapsed 622.45s 
+epoch: 30, iter: 15100, avg_loss: 44.1237, kl/H(z|x): 0.0004, recon: 44.1233,time elapsed 623.51s +epoch: 30, iter: 15150, avg_loss: 43.8119, kl/H(z|x): 0.0005, recon: 43.8115,time elapsed 624.59s +epoch: 30, iter: 15200, avg_loss: 43.8647, kl/H(z|x): 0.0004, recon: 43.8643,time elapsed 625.70s +epoch: 30, iter: 15250, avg_loss: 43.9263, kl/H(z|x): 0.0003, recon: 43.9259,time elapsed 626.80s +epoch: 30, iter: 15300, avg_loss: 44.1386, kl/H(z|x): 0.0004, recon: 44.1382,time elapsed 628.34s +epoch: 30, iter: 15350, avg_loss: 44.2637, kl/H(z|x): 0.0004, recon: 44.2633,time elapsed 629.57s +epoch: 30, iter: 15400, avg_loss: 43.6037, kl/H(z|x): 0.0005, recon: 43.6032,time elapsed 630.79s +epoch: 30, iter: 15450, avg_loss: 43.9861, kl/H(z|x): 0.0004, recon: 43.9857,time elapsed 631.99s +kl weight 1.0000 +VAL --- avg_loss: 42.8454, kl/H(z|x): 0.0005, mi: -0.0501, recon: 42.8449, nll: 42.8454, ppl: 49.1579 +0 active units +TEST --- avg_loss: 42.8717, kl/H(z|x): 0.0005, mi: 0.0334, recon: 42.8712, nll: 42.8717, ppl: 49.2755 +epoch: 31, iter: 15500, avg_loss: 49.3075, kl/H(z|x): 0.0005, recon: 49.3070,time elapsed 641.71s +epoch: 31, iter: 15550, avg_loss: 44.0898, kl/H(z|x): 0.0004, recon: 44.0893,time elapsed 642.82s +epoch: 31, iter: 15600, avg_loss: 43.8914, kl/H(z|x): 0.0004, recon: 43.8910,time elapsed 643.90s +epoch: 31, iter: 15650, avg_loss: 43.8410, kl/H(z|x): 0.0004, recon: 43.8405,time elapsed 645.00s +epoch: 31, iter: 15700, avg_loss: 43.4980, kl/H(z|x): 0.0004, recon: 43.4976,time elapsed 646.09s +epoch: 31, iter: 15750, avg_loss: 44.1395, kl/H(z|x): 0.0004, recon: 44.1391,time elapsed 647.17s +epoch: 31, iter: 15800, avg_loss: 43.8627, kl/H(z|x): 0.0004, recon: 43.8623,time elapsed 648.31s +epoch: 31, iter: 15850, avg_loss: 44.0324, kl/H(z|x): 0.0004, recon: 44.0320,time elapsed 649.39s +epoch: 31, iter: 15900, avg_loss: 43.8421, kl/H(z|x): 0.0005, recon: 43.8416,time elapsed 650.48s +epoch: 31, iter: 15950, avg_loss: 44.5026, kl/H(z|x): 0.0004, recon: 
44.5022,time elapsed 651.68s +kl weight 1.0000 +VAL --- avg_loss: 42.8186, kl/H(z|x): 0.0004, mi: -0.0377, recon: 42.8182, nll: 42.8186, ppl: 49.0381 +0 active units +TEST --- avg_loss: 42.8374, kl/H(z|x): 0.0004, mi: 0.0111, recon: 42.8370, nll: 42.8374, ppl: 49.1222 +epoch: 32, iter: 16000, avg_loss: 43.7023, kl/H(z|x): 0.0004, recon: 43.7019,time elapsed 661.15s +epoch: 32, iter: 16050, avg_loss: 43.7157, kl/H(z|x): 0.0003, recon: 43.7153,time elapsed 662.24s +epoch: 32, iter: 16100, avg_loss: 43.2802, kl/H(z|x): 0.0004, recon: 43.2798,time elapsed 663.37s +epoch: 32, iter: 16150, avg_loss: 44.0081, kl/H(z|x): 0.0004, recon: 44.0077,time elapsed 664.48s +epoch: 32, iter: 16200, avg_loss: 44.0434, kl/H(z|x): 0.0004, recon: 44.0430,time elapsed 665.55s +epoch: 32, iter: 16250, avg_loss: 44.1416, kl/H(z|x): 0.0004, recon: 44.1411,time elapsed 666.63s +epoch: 32, iter: 16300, avg_loss: 43.9179, kl/H(z|x): 0.0005, recon: 43.9174,time elapsed 667.74s +epoch: 32, iter: 16350, avg_loss: 43.9227, kl/H(z|x): 0.0005, recon: 43.9222,time elapsed 668.83s +epoch: 32, iter: 16400, avg_loss: 43.8702, kl/H(z|x): 0.0005, recon: 43.8697,time elapsed 669.93s +epoch: 32, iter: 16450, avg_loss: 43.8965, kl/H(z|x): 0.0006, recon: 43.8959,time elapsed 671.01s +kl weight 1.0000 +VAL --- avg_loss: 42.8250, kl/H(z|x): 0.0004, mi: 0.0814, recon: 42.8246, nll: 42.8250, ppl: 49.0668 +0 active units +TEST --- avg_loss: 42.8447, kl/H(z|x): 0.0004, mi: 0.0192, recon: 42.8443, nll: 42.8447, ppl: 49.1547 +epoch: 33, iter: 16500, avg_loss: 44.2517, kl/H(z|x): 0.0004, recon: 44.2513,time elapsed 680.38s +epoch: 33, iter: 16550, avg_loss: 43.6145, kl/H(z|x): 0.0004, recon: 43.6141,time elapsed 681.48s +epoch: 33, iter: 16600, avg_loss: 44.2139, kl/H(z|x): 0.0004, recon: 44.2134,time elapsed 682.60s +epoch: 33, iter: 16650, avg_loss: 43.5590, kl/H(z|x): 0.0005, recon: 43.5585,time elapsed 683.69s +epoch: 33, iter: 16700, avg_loss: 43.5167, kl/H(z|x): 0.0004, recon: 43.5164,time elapsed 684.78s 
+epoch: 33, iter: 16750, avg_loss: 44.3663, kl/H(z|x): 0.0004, recon: 44.3659,time elapsed 685.95s +epoch: 33, iter: 16800, avg_loss: 43.3762, kl/H(z|x): 0.0005, recon: 43.3757,time elapsed 687.36s +epoch: 33, iter: 16850, avg_loss: 43.6989, kl/H(z|x): 0.0005, recon: 43.6984,time elapsed 688.73s +epoch: 33, iter: 16900, avg_loss: 44.0670, kl/H(z|x): 0.0004, recon: 44.0666,time elapsed 689.83s +epoch: 33, iter: 16950, avg_loss: 44.3268, kl/H(z|x): 0.0004, recon: 44.3263,time elapsed 690.96s +kl weight 1.0000 +VAL --- avg_loss: 42.7766, kl/H(z|x): 0.0005, mi: 0.0011, recon: 42.7761, nll: 42.7766, ppl: 48.8513 +0 active units +TEST --- avg_loss: 42.7852, kl/H(z|x): 0.0005, mi: 0.0312, recon: 42.7847, nll: 42.7852, ppl: 48.8897 +epoch: 34, iter: 17000, avg_loss: 45.9578, kl/H(z|x): 0.0005, recon: 45.9573,time elapsed 700.26s +epoch: 34, iter: 17050, avg_loss: 43.4863, kl/H(z|x): 0.0005, recon: 43.4858,time elapsed 701.34s +epoch: 34, iter: 17100, avg_loss: 43.1265, kl/H(z|x): 0.0004, recon: 43.1261,time elapsed 702.45s +epoch: 34, iter: 17150, avg_loss: 44.2183, kl/H(z|x): 0.0004, recon: 44.2179,time elapsed 703.54s +epoch: 34, iter: 17200, avg_loss: 44.0181, kl/H(z|x): 0.0004, recon: 44.0177,time elapsed 704.61s +epoch: 34, iter: 17250, avg_loss: 43.8095, kl/H(z|x): 0.0004, recon: 43.8091,time elapsed 705.71s +epoch: 34, iter: 17300, avg_loss: 43.8077, kl/H(z|x): 0.0005, recon: 43.8073,time elapsed 706.89s +epoch: 34, iter: 17350, avg_loss: 43.9956, kl/H(z|x): 0.0005, recon: 43.9951,time elapsed 708.04s +epoch: 34, iter: 17400, avg_loss: 44.3191, kl/H(z|x): 0.0004, recon: 44.3187,time elapsed 709.14s +epoch: 34, iter: 17450, avg_loss: 43.7856, kl/H(z|x): 0.0004, recon: 43.7852,time elapsed 710.23s +kl weight 1.0000 +VAL --- avg_loss: 42.8082, kl/H(z|x): 0.0005, mi: -0.0863, recon: 42.8077, nll: 42.8082, ppl: 48.9918 +0 active units +new lr: 0.250000 +TEST --- avg_loss: 42.7279, kl/H(z|x): 0.0005, mi: -0.0938, recon: 42.7275, nll: 42.7279, ppl: 48.6357 +epoch: 35, 
iter: 17500, avg_loss: 41.7098, kl/H(z|x): 0.0005, recon: 41.7093,time elapsed 719.87s +epoch: 35, iter: 17550, avg_loss: 43.6619, kl/H(z|x): 0.0004, recon: 43.6614,time elapsed 720.96s +epoch: 35, iter: 17600, avg_loss: 43.7670, kl/H(z|x): 0.0004, recon: 43.7666,time elapsed 722.03s +epoch: 35, iter: 17650, avg_loss: 43.6828, kl/H(z|x): 0.0003, recon: 43.6825,time elapsed 723.16s +epoch: 35, iter: 17700, avg_loss: 43.8141, kl/H(z|x): 0.0003, recon: 43.8138,time elapsed 724.24s +epoch: 35, iter: 17750, avg_loss: 43.3714, kl/H(z|x): 0.0003, recon: 43.3711,time elapsed 725.33s +epoch: 35, iter: 17800, avg_loss: 43.9177, kl/H(z|x): 0.0003, recon: 43.9174,time elapsed 726.41s +epoch: 35, iter: 17850, avg_loss: 43.1331, kl/H(z|x): 0.0003, recon: 43.1329,time elapsed 727.53s +epoch: 35, iter: 17900, avg_loss: 43.3851, kl/H(z|x): 0.0002, recon: 43.3849,time elapsed 728.62s +epoch: 35, iter: 17950, avg_loss: 43.5729, kl/H(z|x): 0.0002, recon: 43.5727,time elapsed 729.68s +kl weight 1.0000 +VAL --- avg_loss: 42.6877, kl/H(z|x): 0.0002, mi: -0.1456, recon: 42.6875, nll: 42.6877, ppl: 48.4583 +0 active units +update best loss +TEST --- avg_loss: 42.6792, kl/H(z|x): 0.0002, mi: -0.0167, recon: 42.6791, nll: 42.6792, ppl: 48.4209 +epoch: 36, iter: 18000, avg_loss: 43.9516, kl/H(z|x): 0.0002, recon: 43.9514,time elapsed 739.03s +epoch: 36, iter: 18050, avg_loss: 43.7369, kl/H(z|x): 0.0002, recon: 43.7367,time elapsed 740.13s +epoch: 36, iter: 18100, avg_loss: 43.5138, kl/H(z|x): 0.0003, recon: 43.5136,time elapsed 741.29s +epoch: 36, iter: 18150, avg_loss: 43.5765, kl/H(z|x): 0.0002, recon: 43.5763,time elapsed 742.49s +epoch: 36, iter: 18200, avg_loss: 43.2139, kl/H(z|x): 0.0002, recon: 43.2137,time elapsed 743.57s +epoch: 36, iter: 18250, avg_loss: 43.4710, kl/H(z|x): 0.0002, recon: 43.4709,time elapsed 744.67s +epoch: 36, iter: 18300, avg_loss: 43.4375, kl/H(z|x): 0.0002, recon: 43.4373,time elapsed 745.75s +epoch: 36, iter: 18350, avg_loss: 43.5230, kl/H(z|x): 0.0001, recon: 
43.5228,time elapsed 746.84s +epoch: 36, iter: 18400, avg_loss: 43.9662, kl/H(z|x): 0.0001, recon: 43.9661,time elapsed 748.13s +epoch: 36, iter: 18450, avg_loss: 43.7149, kl/H(z|x): 0.0001, recon: 43.7147,time elapsed 749.24s +kl weight 1.0000 +VAL --- avg_loss: 42.6654, kl/H(z|x): 0.0002, mi: -0.0094, recon: 42.6652, nll: 42.6654, ppl: 48.3601 +0 active units +update best loss +TEST --- avg_loss: 42.6540, kl/H(z|x): 0.0002, mi: -0.0618, recon: 42.6538, nll: 42.6540, ppl: 48.3100 +epoch: 37, iter: 18500, avg_loss: 40.8473, kl/H(z|x): 0.0002, recon: 40.8471,time elapsed 758.54s +epoch: 37, iter: 18550, avg_loss: 43.1240, kl/H(z|x): 0.0002, recon: 43.1239,time elapsed 759.63s +epoch: 37, iter: 18600, avg_loss: 43.6882, kl/H(z|x): 0.0002, recon: 43.6880,time elapsed 760.72s +epoch: 37, iter: 18650, avg_loss: 43.5877, kl/H(z|x): 0.0001, recon: 43.5876,time elapsed 761.81s +epoch: 37, iter: 18700, avg_loss: 44.0302, kl/H(z|x): 0.0001, recon: 44.0301,time elapsed 762.93s +epoch: 37, iter: 18750, avg_loss: 43.6260, kl/H(z|x): 0.0001, recon: 43.6258,time elapsed 764.03s +epoch: 37, iter: 18800, avg_loss: 43.3254, kl/H(z|x): 0.0001, recon: 43.3252,time elapsed 765.13s +epoch: 37, iter: 18850, avg_loss: 43.7614, kl/H(z|x): 0.0002, recon: 43.7612,time elapsed 766.22s +epoch: 37, iter: 18900, avg_loss: 43.4881, kl/H(z|x): 0.0002, recon: 43.4879,time elapsed 767.32s +epoch: 37, iter: 18950, avg_loss: 43.3303, kl/H(z|x): 0.0002, recon: 43.3301,time elapsed 768.47s +kl weight 1.0000 +VAL --- avg_loss: 42.6456, kl/H(z|x): 0.0002, mi: -0.0024, recon: 42.6454, nll: 42.6456, ppl: 48.2729 +0 active units +update best loss +TEST --- avg_loss: 42.6398, kl/H(z|x): 0.0002, mi: -0.0623, recon: 42.6396, nll: 42.6398, ppl: 48.2474 +epoch: 38, iter: 19000, avg_loss: 39.2798, kl/H(z|x): 0.0002, recon: 39.2796,time elapsed 777.94s +epoch: 38, iter: 19050, avg_loss: 43.5151, kl/H(z|x): 0.0002, recon: 43.5150,time elapsed 779.02s +epoch: 38, iter: 19100, avg_loss: 43.5278, kl/H(z|x): 0.0001, 
recon: 43.5277,time elapsed 780.12s +epoch: 38, iter: 19150, avg_loss: 43.0568, kl/H(z|x): 0.0001, recon: 43.0567,time elapsed 781.20s +epoch: 38, iter: 19200, avg_loss: 43.2855, kl/H(z|x): 0.0001, recon: 43.2854,time elapsed 782.34s +epoch: 38, iter: 19250, avg_loss: 43.9021, kl/H(z|x): 0.0001, recon: 43.9020,time elapsed 783.41s +epoch: 38, iter: 19300, avg_loss: 43.6581, kl/H(z|x): 0.0001, recon: 43.6580,time elapsed 784.47s +epoch: 38, iter: 19350, avg_loss: 43.7079, kl/H(z|x): 0.0001, recon: 43.7078,time elapsed 785.55s +epoch: 38, iter: 19400, avg_loss: 43.3235, kl/H(z|x): 0.0001, recon: 43.3233,time elapsed 786.63s +epoch: 38, iter: 19450, avg_loss: 43.6018, kl/H(z|x): 0.0001, recon: 43.6017,time elapsed 787.74s +kl weight 1.0000 +VAL --- avg_loss: 42.6433, kl/H(z|x): 0.0001, mi: 0.0521, recon: 42.6432, nll: 42.6433, ppl: 48.2628 +0 active units +update best loss +TEST --- avg_loss: 42.6409, kl/H(z|x): 0.0001, mi: -0.0322, recon: 42.6408, nll: 42.6409, ppl: 48.2525 +epoch: 39, iter: 19500, avg_loss: 42.6294, kl/H(z|x): 0.0001, recon: 42.6293,time elapsed 797.07s +epoch: 39, iter: 19550, avg_loss: 43.3645, kl/H(z|x): 0.0001, recon: 43.3644,time elapsed 798.21s +epoch: 39, iter: 19600, avg_loss: 43.5161, kl/H(z|x): 0.0002, recon: 43.5160,time elapsed 799.30s +epoch: 39, iter: 19650, avg_loss: 43.6186, kl/H(z|x): 0.0001, recon: 43.6185,time elapsed 800.42s +epoch: 39, iter: 19700, avg_loss: 43.1040, kl/H(z|x): 0.0001, recon: 43.1039,time elapsed 801.51s +epoch: 39, iter: 19750, avg_loss: 43.1631, kl/H(z|x): 0.0001, recon: 43.1630,time elapsed 802.64s +epoch: 39, iter: 19800, avg_loss: 43.8840, kl/H(z|x): 0.0001, recon: 43.8839,time elapsed 803.72s +epoch: 39, iter: 19850, avg_loss: 43.0590, kl/H(z|x): 0.0002, recon: 43.0588,time elapsed 804.80s +epoch: 39, iter: 19900, avg_loss: 43.7155, kl/H(z|x): 0.0001, recon: 43.7154,time elapsed 805.88s +epoch: 39, iter: 19950, avg_loss: 43.7701, kl/H(z|x): 0.0001, recon: 43.7700,time elapsed 807.06s +kl weight 1.0000 +VAL 
--- avg_loss: 42.6654, kl/H(z|x): 0.0001, mi: -0.0592, recon: 42.6653, nll: 42.6654, ppl: 48.3601 +0 active units +TEST --- avg_loss: 42.6620, kl/H(z|x): 0.0001, mi: 0.0123, recon: 42.6618, nll: 42.6620, ppl: 48.3449 +epoch: 40, iter: 20000, avg_loss: 40.8213, kl/H(z|x): 0.0001, recon: 40.8212,time elapsed 816.66s +epoch: 40, iter: 20050, avg_loss: 43.6310, kl/H(z|x): 0.0001, recon: 43.6309,time elapsed 817.78s +epoch: 40, iter: 20100, avg_loss: 43.3006, kl/H(z|x): 0.0001, recon: 43.3004,time elapsed 818.86s +epoch: 40, iter: 20150, avg_loss: 43.5762, kl/H(z|x): 0.0001, recon: 43.5761,time elapsed 819.95s +epoch: 40, iter: 20200, avg_loss: 43.1754, kl/H(z|x): 0.0001, recon: 43.1753,time elapsed 821.04s +epoch: 40, iter: 20250, avg_loss: 43.1876, kl/H(z|x): 0.0001, recon: 43.1875,time elapsed 822.11s +epoch: 40, iter: 20300, avg_loss: 43.6291, kl/H(z|x): 0.0001, recon: 43.6290,time elapsed 823.24s +epoch: 40, iter: 20350, avg_loss: 43.5045, kl/H(z|x): 0.0001, recon: 43.5044,time elapsed 824.33s +epoch: 40, iter: 20400, avg_loss: 43.7003, kl/H(z|x): 0.0001, recon: 43.7002,time elapsed 825.44s +epoch: 40, iter: 20450, avg_loss: 43.0749, kl/H(z|x): 0.0001, recon: 43.0748,time elapsed 826.51s +kl weight 1.0000 +VAL --- avg_loss: 42.6483, kl/H(z|x): 0.0001, mi: 0.1020, recon: 42.6482, nll: 42.6483, ppl: 48.2851 +0 active units +TEST --- avg_loss: 42.6362, kl/H(z|x): 0.0001, mi: 0.0486, recon: 42.6361, nll: 42.6362, ppl: 48.2319 +epoch: 41, iter: 20500, avg_loss: 46.2897, kl/H(z|x): 0.0001, recon: 46.2896,time elapsed 836.49s +epoch: 41, iter: 20550, avg_loss: 43.3702, kl/H(z|x): 0.0001, recon: 43.3701,time elapsed 837.65s +epoch: 41, iter: 20600, avg_loss: 43.3133, kl/H(z|x): 0.0001, recon: 43.3132,time elapsed 838.74s +epoch: 41, iter: 20650, avg_loss: 43.0952, kl/H(z|x): 0.0001, recon: 43.0951,time elapsed 839.82s +epoch: 41, iter: 20700, avg_loss: 43.4714, kl/H(z|x): 0.0001, recon: 43.4712,time elapsed 840.93s +epoch: 41, iter: 20750, avg_loss: 43.2933, kl/H(z|x): 
0.0001, recon: 43.2931,time elapsed 842.23s +epoch: 41, iter: 20800, avg_loss: 43.1513, kl/H(z|x): 0.0001, recon: 43.1511,time elapsed 843.60s +epoch: 41, iter: 20850, avg_loss: 43.4006, kl/H(z|x): 0.0002, recon: 43.4004,time elapsed 845.02s +epoch: 41, iter: 20900, avg_loss: 43.9970, kl/H(z|x): 0.0002, recon: 43.9968,time elapsed 846.14s +epoch: 41, iter: 20950, avg_loss: 43.5226, kl/H(z|x): 0.0002, recon: 43.5224,time elapsed 847.26s +kl weight 1.0000 +VAL --- avg_loss: 42.6132, kl/H(z|x): 0.0001, mi: -0.0262, recon: 42.6130, nll: 42.6132, ppl: 48.1309 +0 active units +update best loss +TEST --- avg_loss: 42.6104, kl/H(z|x): 0.0001, mi: 0.0578, recon: 42.6103, nll: 42.6104, ppl: 48.1187 +epoch: 42, iter: 21000, avg_loss: 45.0442, kl/H(z|x): 0.0001, recon: 45.0440,time elapsed 856.74s +epoch: 42, iter: 21050, avg_loss: 43.2709, kl/H(z|x): 0.0001, recon: 43.2708,time elapsed 857.93s +epoch: 42, iter: 21100, avg_loss: 43.5428, kl/H(z|x): 0.0001, recon: 43.5427,time elapsed 859.10s +epoch: 42, iter: 21150, avg_loss: 43.0828, kl/H(z|x): 0.0001, recon: 43.0827,time elapsed 860.20s +epoch: 42, iter: 21200, avg_loss: 43.4841, kl/H(z|x): 0.0001, recon: 43.4839,time elapsed 861.33s +epoch: 42, iter: 21250, avg_loss: 43.1936, kl/H(z|x): 0.0001, recon: 43.1935,time elapsed 862.47s +epoch: 42, iter: 21300, avg_loss: 43.5686, kl/H(z|x): 0.0001, recon: 43.5684,time elapsed 863.68s +epoch: 42, iter: 21350, avg_loss: 43.8472, kl/H(z|x): 0.0001, recon: 43.8471,time elapsed 864.78s +epoch: 42, iter: 21400, avg_loss: 43.1563, kl/H(z|x): 0.0001, recon: 43.1562,time elapsed 865.87s +epoch: 42, iter: 21450, avg_loss: 43.3632, kl/H(z|x): 0.0001, recon: 43.3630,time elapsed 866.99s +kl weight 1.0000 +VAL --- avg_loss: 42.6019, kl/H(z|x): 0.0001, mi: 0.0535, recon: 42.6018, nll: 42.6019, ppl: 48.0818 +0 active units +update best loss +TEST --- avg_loss: 42.6078, kl/H(z|x): 0.0001, mi: 0.0234, recon: 42.6077, nll: 42.6078, ppl: 48.1076 +epoch: 43, iter: 21500, avg_loss: 42.9676, kl/H(z|x): 
0.0001, recon: 42.9675,time elapsed 876.50s +epoch: 43, iter: 21550, avg_loss: 43.4092, kl/H(z|x): 0.0001, recon: 43.4091,time elapsed 877.58s +epoch: 43, iter: 21600, avg_loss: 43.5033, kl/H(z|x): 0.0001, recon: 43.5032,time elapsed 878.69s +epoch: 43, iter: 21650, avg_loss: 43.5289, kl/H(z|x): 0.0002, recon: 43.5287,time elapsed 879.77s +epoch: 43, iter: 21700, avg_loss: 42.9456, kl/H(z|x): 0.0002, recon: 42.9455,time elapsed 880.87s +epoch: 43, iter: 21750, avg_loss: 43.5842, kl/H(z|x): 0.0002, recon: 43.5841,time elapsed 881.98s +epoch: 43, iter: 21800, avg_loss: 43.6099, kl/H(z|x): 0.0001, recon: 43.6098,time elapsed 883.08s +epoch: 43, iter: 21850, avg_loss: 42.9298, kl/H(z|x): 0.0001, recon: 42.9297,time elapsed 884.16s +epoch: 43, iter: 21900, avg_loss: 43.4812, kl/H(z|x): 0.0001, recon: 43.4811,time elapsed 885.24s +epoch: 43, iter: 21950, avg_loss: 43.5893, kl/H(z|x): 0.0001, recon: 43.5892,time elapsed 886.33s +kl weight 1.0000 +VAL --- avg_loss: 42.5868, kl/H(z|x): 0.0002, mi: -0.0469, recon: 42.5866, nll: 42.5868, ppl: 48.0157 +0 active units +update best loss +TEST --- avg_loss: 42.5918, kl/H(z|x): 0.0002, mi: -0.0134, recon: 42.5916, nll: 42.5918, ppl: 48.0376 +epoch: 44, iter: 22000, avg_loss: 41.1623, kl/H(z|x): 0.0002, recon: 41.1621,time elapsed 896.15s +epoch: 44, iter: 22050, avg_loss: 43.6225, kl/H(z|x): 0.0002, recon: 43.6223,time elapsed 897.24s +epoch: 44, iter: 22100, avg_loss: 43.0115, kl/H(z|x): 0.0001, recon: 43.0113,time elapsed 898.35s +epoch: 44, iter: 22150, avg_loss: 43.8803, kl/H(z|x): 0.0001, recon: 43.8802,time elapsed 899.44s +epoch: 44, iter: 22200, avg_loss: 43.7196, kl/H(z|x): 0.0001, recon: 43.7195,time elapsed 900.54s +epoch: 44, iter: 22250, avg_loss: 43.1994, kl/H(z|x): 0.0001, recon: 43.1993,time elapsed 901.65s +epoch: 44, iter: 22300, avg_loss: 43.2972, kl/H(z|x): 0.0001, recon: 43.2971,time elapsed 902.77s +epoch: 44, iter: 22350, avg_loss: 43.2614, kl/H(z|x): 0.0001, recon: 43.2613,time elapsed 903.85s +epoch: 44, 
iter: 22400, avg_loss: 43.0387, kl/H(z|x): 0.0001, recon: 43.0386,time elapsed 904.93s +epoch: 44, iter: 22450, avg_loss: 43.5179, kl/H(z|x): 0.0001, recon: 43.5178,time elapsed 906.00s +kl weight 1.0000 +VAL --- avg_loss: 42.6205, kl/H(z|x): 0.0001, mi: -0.0640, recon: 42.6204, nll: 42.6205, ppl: 48.1632 +0 active units +TEST --- avg_loss: 42.6170, kl/H(z|x): 0.0001, mi: 0.0838, recon: 42.6168, nll: 42.6170, ppl: 48.1476 +epoch: 45, iter: 22500, avg_loss: 45.8301, kl/H(z|x): 0.0001, recon: 45.8299,time elapsed 915.37s +epoch: 45, iter: 22550, avg_loss: 43.7427, kl/H(z|x): 0.0001, recon: 43.7426,time elapsed 916.46s +epoch: 45, iter: 22600, avg_loss: 43.1537, kl/H(z|x): 0.0001, recon: 43.1536,time elapsed 917.58s +epoch: 45, iter: 22650, avg_loss: 42.6766, kl/H(z|x): 0.0001, recon: 42.6765,time elapsed 918.67s +epoch: 45, iter: 22700, avg_loss: 43.2644, kl/H(z|x): 0.0001, recon: 43.2643,time elapsed 919.77s +epoch: 45, iter: 22750, avg_loss: 43.8260, kl/H(z|x): 0.0001, recon: 43.8259,time elapsed 920.87s +epoch: 45, iter: 22800, avg_loss: 43.0574, kl/H(z|x): 0.0001, recon: 43.0573,time elapsed 922.02s +epoch: 45, iter: 22850, avg_loss: 43.5961, kl/H(z|x): 0.0001, recon: 43.5960,time elapsed 923.16s +epoch: 45, iter: 22900, avg_loss: 43.7173, kl/H(z|x): 0.0001, recon: 43.7172,time elapsed 924.25s +epoch: 45, iter: 22950, avg_loss: 43.1578, kl/H(z|x): 0.0002, recon: 43.1577,time elapsed 925.31s +kl weight 1.0000 +VAL --- avg_loss: 42.6148, kl/H(z|x): 0.0001, mi: -0.0306, recon: 42.6146, nll: 42.6148, ppl: 48.1380 +0 active units +TEST --- avg_loss: 42.6229, kl/H(z|x): 0.0001, mi: 0.0225, recon: 42.6227, nll: 42.6229, ppl: 48.1734 +epoch: 46, iter: 23000, avg_loss: 45.8865, kl/H(z|x): 0.0001, recon: 45.8864,time elapsed 935.17s +epoch: 46, iter: 23050, avg_loss: 42.8259, kl/H(z|x): 0.0002, recon: 42.8257,time elapsed 936.26s +epoch: 46, iter: 23100, avg_loss: 43.3159, kl/H(z|x): 0.0001, recon: 43.3158,time elapsed 937.35s +epoch: 46, iter: 23150, avg_loss: 43.1671, 
kl/H(z|x): 0.0001, recon: 43.1670,time elapsed 938.47s +epoch: 46, iter: 23200, avg_loss: 43.3078, kl/H(z|x): 0.0001, recon: 43.3077,time elapsed 939.55s +epoch: 46, iter: 23250, avg_loss: 43.6324, kl/H(z|x): 0.0001, recon: 43.6323,time elapsed 940.64s +epoch: 46, iter: 23300, avg_loss: 43.4459, kl/H(z|x): 0.0001, recon: 43.4458,time elapsed 941.71s +epoch: 46, iter: 23350, avg_loss: 43.8197, kl/H(z|x): 0.0002, recon: 43.8195,time elapsed 942.83s +epoch: 46, iter: 23400, avg_loss: 43.6614, kl/H(z|x): 0.0001, recon: 43.6613,time elapsed 943.92s +epoch: 46, iter: 23450, avg_loss: 43.3612, kl/H(z|x): 0.0001, recon: 43.3610,time elapsed 945.02s +kl weight 1.0000 +VAL --- avg_loss: 42.5967, kl/H(z|x): 0.0001, mi: -0.0161, recon: 42.5966, nll: 42.5967, ppl: 48.0588 +0 active units +TEST --- avg_loss: 42.6032, kl/H(z|x): 0.0001, mi: -0.0448, recon: 42.6031, nll: 42.6032, ppl: 48.0873 +epoch: 47, iter: 23500, avg_loss: 44.5862, kl/H(z|x): 0.0001, recon: 44.5860,time elapsed 954.40s +epoch: 47, iter: 23550, avg_loss: 43.0005, kl/H(z|x): 0.0001, recon: 43.0004,time elapsed 955.51s +epoch: 47, iter: 23600, avg_loss: 43.6972, kl/H(z|x): 0.0001, recon: 43.6971,time elapsed 956.60s +epoch: 47, iter: 23650, avg_loss: 43.2438, kl/H(z|x): 0.0001, recon: 43.2437,time elapsed 957.72s +epoch: 47, iter: 23700, avg_loss: 43.3149, kl/H(z|x): 0.0001, recon: 43.3147,time elapsed 958.81s +epoch: 47, iter: 23750, avg_loss: 43.3793, kl/H(z|x): 0.0001, recon: 43.3792,time elapsed 959.92s +epoch: 47, iter: 23800, avg_loss: 43.4107, kl/H(z|x): 0.0002, recon: 43.4105,time elapsed 960.99s +epoch: 47, iter: 23850, avg_loss: 43.8430, kl/H(z|x): 0.0001, recon: 43.8428,time elapsed 962.09s +epoch: 47, iter: 23900, avg_loss: 42.8022, kl/H(z|x): 0.0001, recon: 42.8021,time elapsed 963.20s +epoch: 47, iter: 23950, avg_loss: 43.4718, kl/H(z|x): 0.0001, recon: 43.4717,time elapsed 964.31s +kl weight 1.0000 +VAL --- avg_loss: 42.5876, kl/H(z|x): 0.0001, mi: -0.0421, recon: 42.5875, nll: 42.5876, ppl: 
48.0192 +0 active units +TEST --- avg_loss: 42.5844, kl/H(z|x): 0.0001, mi: -0.0502, recon: 42.5842, nll: 42.5844, ppl: 48.0050 +epoch: 48, iter: 24000, avg_loss: 44.4487, kl/H(z|x): 0.0001, recon: 44.4486,time elapsed 973.63s +epoch: 48, iter: 24050, avg_loss: 42.7031, kl/H(z|x): 0.0001, recon: 42.7030,time elapsed 974.70s +epoch: 48, iter: 24100, avg_loss: 43.6009, kl/H(z|x): 0.0001, recon: 43.6008,time elapsed 975.78s +epoch: 48, iter: 24150, avg_loss: 43.3372, kl/H(z|x): 0.0001, recon: 43.3371,time elapsed 976.87s +epoch: 48, iter: 24200, avg_loss: 43.4364, kl/H(z|x): 0.0001, recon: 43.4363,time elapsed 978.00s +epoch: 48, iter: 24250, avg_loss: 43.0002, kl/H(z|x): 0.0001, recon: 43.0001,time elapsed 979.08s +epoch: 48, iter: 24300, avg_loss: 43.3012, kl/H(z|x): 0.0001, recon: 43.3010,time elapsed 980.16s +epoch: 48, iter: 24350, avg_loss: 43.4798, kl/H(z|x): 0.0001, recon: 43.4796,time elapsed 981.24s +epoch: 48, iter: 24400, avg_loss: 43.6218, kl/H(z|x): 0.0001, recon: 43.6216,time elapsed 982.42s +epoch: 48, iter: 24450, avg_loss: 43.2860, kl/H(z|x): 0.0001, recon: 43.2858,time elapsed 983.52s +kl weight 1.0000 +VAL --- avg_loss: 42.5813, kl/H(z|x): 0.0001, mi: 0.0251, recon: 42.5811, nll: 42.5813, ppl: 47.9915 +0 active units +update best loss +TEST --- avg_loss: 42.5769, kl/H(z|x): 0.0001, mi: 0.0155, recon: 42.5768, nll: 42.5769, ppl: 47.9726 +epoch: 49, iter: 24500, avg_loss: 42.8160, kl/H(z|x): 0.0001, recon: 42.8159,time elapsed 992.94s +epoch: 49, iter: 24550, avg_loss: 43.3220, kl/H(z|x): 0.0001, recon: 43.3219,time elapsed 994.02s +epoch: 49, iter: 24600, avg_loss: 43.2801, kl/H(z|x): 0.0001, recon: 43.2799,time elapsed 995.12s +epoch: 49, iter: 24650, avg_loss: 43.4456, kl/H(z|x): 0.0001, recon: 43.4455,time elapsed 996.21s +epoch: 49, iter: 24700, avg_loss: 43.3676, kl/H(z|x): 0.0001, recon: 43.3675,time elapsed 997.28s +epoch: 49, iter: 24750, avg_loss: 42.8345, kl/H(z|x): 0.0001, recon: 42.8343,time elapsed 998.40s +epoch: 49, iter: 24800, 
avg_loss: 43.6380, kl/H(z|x): 0.0001, recon: 43.6378,time elapsed 999.48s +epoch: 49, iter: 24850, avg_loss: 43.1312, kl/H(z|x): 0.0001, recon: 43.1311,time elapsed 1000.56s +epoch: 49, iter: 24900, avg_loss: 43.2566, kl/H(z|x): 0.0001, recon: 43.2565,time elapsed 1001.64s +epoch: 49, iter: 24950, avg_loss: 43.8295, kl/H(z|x): 0.0001, recon: 43.8293,time elapsed 1002.77s +kl weight 1.0000 +VAL --- avg_loss: 42.6173, kl/H(z|x): 0.0001, mi: 0.0351, recon: 42.6171, nll: 42.6173, ppl: 48.1488 +0 active units +TEST --- avg_loss: 42.6258, kl/H(z|x): 0.0001, mi: -0.1673, recon: 42.6256, nll: 42.6258, ppl: 48.1860 +TEST --- avg_loss: 42.5708, kl/H(z|x): 0.0001, mi: 0.0651, recon: 42.5707, nll: 42.5708, ppl: 47.9457 +0 active units +iw nll computing 00% +iw nll computing 10% +iw nll computing 20% +iw nll computing 30% +iw nll computing 40% +iw nll computing 50% +iw nll computing 60% +iw nll computing 70% +iw nll computing 80% +iw nll computing 90% +iw nll: 42.5245, iw ppl: 47.7445 diff --git a/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt new file mode 100644 index 0000000..6b00894 Binary files /dev/null and b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt differ diff --git a/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt new file mode 100644 index 0000000..f53c3c4 --- /dev/null +++ b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt @@ -0,0 +1,5 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=0.0, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=0.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', 
log_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 diff --git a/models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt new file mode 100644 index 0000000..6794a71 --- /dev/null +++ b/models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt @@ -0,0 +1,5 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=0.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt 
+Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 diff --git a/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt new file mode 100644 index 0000000..8cef7ec --- /dev/null +++ b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt @@ -0,0 +1,7 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=0.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=10) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 91.9735, kl/H(z|x): 15.9433, mi: 0.0662, recon: 76.0302,au 0, time elapsed 4.20s +epoch: 0, iter: 50, avg_loss: 81.7852, kl/H(z|x): 17.6312, mi: 3.5894, recon: 64.1539,au 32, time elapsed 9.39s diff --git a/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt new file mode 100644 index 0000000..8f8f059 Binary files /dev/null and 
b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt differ diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt new file mode 100644 index 0000000..f325a1c --- /dev/null +++ b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt @@ -0,0 +1,162 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 91.9838, kl/H(z|x): 15.9542, mi: 0.0727, recon: 76.0295,au 0, time elapsed 3.89s +gamma Parameter containing: +tensor([0.6361, 0.8692, 1.0447, 1.0875, 0.8834, 1.1141, 0.9036, 0.9981, 0.7376, + 1.0701, 1.0894, 0.8052, 1.0607, 0.9636, 0.8960, 1.0999, 1.2594, 1.0827, + 0.9545, 1.0415, 0.8483, 0.7461, 0.9029, 1.0743, 0.9385, 1.0999, 0.7905, + 1.1520, 1.0535, 1.0816, 1.3140, 1.0453], requires_grad=True) +train loc mean tensor([-0.0733, -0.0924, -0.0358, -0.0838, -0.0454, 0.0331, -0.0310, 0.0068, + -0.0078, 0.0136, 
-0.0392, -0.0524, 0.0494, 0.0408, -0.0388, -0.0561, + -0.0274, -0.0274, 0.0320, -0.1015, 0.0355, 0.0182, 0.0461, 0.0107, + 0.0683, -0.0382, -0.0352, -0.0434, 0.0076, 0.0446, -0.0125, -0.0292]) +train scale std tensor([7.7909e-05, 7.0655e-05, 8.6690e-05, 3.9981e-05, 5.8483e-05, 5.1343e-05, + 5.1193e-05, 7.1845e-05, 4.4821e-05, 4.0294e-05, 2.5077e-05, 5.2438e-05, + 3.7955e-05, 6.1363e-05, 5.4091e-05, 4.8194e-05, 7.3150e-05, 8.1629e-05, + 4.1456e-05, 6.3313e-05, 8.4339e-05, 2.3192e-05, 9.8620e-05, 2.6473e-05, + 6.0505e-05, 4.4957e-05, 3.8171e-05, 7.3833e-05, 3.8787e-05, 5.0260e-05, + 3.9497e-05, 5.0977e-05]) +train scale mean tensor([1.0283, 1.0284, 1.0290, 1.0281, 1.0284, 1.0283, 1.0283, 1.0281, 1.0282, + 1.0285, 1.0286, 1.0284, 1.0279, 1.0284, 1.0276, 1.0282, 1.0295, 1.0288, + 1.0292, 1.0284, 1.0287, 1.0285, 1.0282, 1.0289, 1.0293, 1.0279, 1.0287, + 1.0282, 1.0281, 1.0293, 1.0288, 1.0287]) +epoch: 0, iter: 50, avg_loss: 68.9229, kl/H(z|x): 1.8562, mi: -0.0067, recon: 67.0667,au 0, time elapsed 8.64s +epoch: 0, iter: 100, avg_loss: 61.1798, kl/H(z|x): 0.2321, mi: 0.0196, recon: 60.9478,au 0, time elapsed 13.72s +epoch: 0, iter: 150, avg_loss: 60.0701, kl/H(z|x): 0.3995, mi: -0.1306, recon: 59.6706,au 0, time elapsed 19.49s +epoch: 0, iter: 200, avg_loss: 57.2679, kl/H(z|x): 0.1905, mi: -0.0269, recon: 57.0774,au 0, time elapsed 24.24s +epoch: 0, iter: 250, avg_loss: 54.3758, kl/H(z|x): 0.1566, mi: -0.0273, recon: 54.2192,au 0, time elapsed 28.84s +epoch: 0, iter: 300, avg_loss: 52.8017, kl/H(z|x): 0.1288, mi: -0.0506, recon: 52.6729,au 0, time elapsed 33.81s +epoch: 0, iter: 350, avg_loss: 51.9674, kl/H(z|x): 0.1882, mi: 0.0187, recon: 51.7792,au 0, time elapsed 38.53s +epoch: 0, iter: 400, avg_loss: 51.5230, kl/H(z|x): 0.1526, mi: -0.0481, recon: 51.3704,au 0, time elapsed 43.18s +epoch: 0, iter: 450, avg_loss: 51.2951, kl/H(z|x): 0.1565, mi: 0.0256, recon: 51.1385,au 0, time elapsed 47.97s +kl weight 1.0000 +VAL --- avg_loss: 48.1727, kl/H(z|x): 0.1342, mi: 
-0.2018, recon: 48.0385, nll: 48.1727, ppl: 79.7848 +0 active units +update best loss +TEST --- avg_loss: 48.2219, kl/H(z|x): 0.1342, mi: -0.0120, recon: 48.0877, nll: 48.2219, ppl: 80.1426 +epoch: 1, iter: 500, avg_loss: 47.6891, kl/H(z|x): 0.1346, recon: 47.5546,time elapsed 62.33s +epoch: 1, iter: 550, avg_loss: 50.0957, kl/H(z|x): 0.1378, recon: 49.9579,time elapsed 63.77s +epoch: 1, iter: 600, avg_loss: 50.4986, kl/H(z|x): 0.1395, recon: 50.3591,time elapsed 65.03s +epoch: 1, iter: 650, avg_loss: 49.5509, kl/H(z|x): 0.1456, recon: 49.4053,time elapsed 66.35s +epoch: 1, iter: 700, avg_loss: 49.5276, kl/H(z|x): 0.2423, recon: 49.2853,time elapsed 67.58s +epoch: 1, iter: 750, avg_loss: 49.0258, kl/H(z|x): 0.1555, recon: 48.8703,time elapsed 69.08s +epoch: 1, iter: 800, avg_loss: 48.6905, kl/H(z|x): 0.1673, recon: 48.5232,time elapsed 70.86s +epoch: 1, iter: 850, avg_loss: 48.3193, kl/H(z|x): 0.1522, recon: 48.1671,time elapsed 72.32s +epoch: 1, iter: 900, avg_loss: 48.5319, kl/H(z|x): 0.1697, recon: 48.3622,time elapsed 73.62s +epoch: 1, iter: 950, avg_loss: 47.5075, kl/H(z|x): 0.1012, recon: 47.4063,time elapsed 74.94s +kl weight 1.0000 +VAL --- avg_loss: 45.2795, kl/H(z|x): 0.0738, mi: 0.0029, recon: 45.2057, nll: 45.2795, ppl: 61.3332 +0 active units +update best loss +TEST --- avg_loss: 45.2758, kl/H(z|x): 0.0738, mi: -0.0522, recon: 45.2020, nll: 45.2758, ppl: 61.3126 +epoch: 2, iter: 1000, avg_loss: 45.3208, kl/H(z|x): 0.0742, recon: 45.2465,time elapsed 87.48s +epoch: 2, iter: 1050, avg_loss: 47.0271, kl/H(z|x): 0.0944, recon: 46.9327,time elapsed 89.00s +epoch: 2, iter: 1100, avg_loss: 47.5806, kl/H(z|x): 0.2097, recon: 47.3709,time elapsed 91.38s +epoch: 2, iter: 1150, avg_loss: 47.6823, kl/H(z|x): 0.1700, recon: 47.5123,time elapsed 92.75s +epoch: 2, iter: 1200, avg_loss: 47.0332, kl/H(z|x): 0.1020, recon: 46.9313,time elapsed 93.91s +epoch: 2, iter: 1250, avg_loss: 47.2249, kl/H(z|x): 0.0633, recon: 47.1616,time elapsed 95.02s +epoch: 2, iter: 1300, 
avg_loss: 46.9512, kl/H(z|x): 0.1286, recon: 46.8226,time elapsed 96.24s +epoch: 2, iter: 1350, avg_loss: 46.9885, kl/H(z|x): 0.1035, recon: 46.8850,time elapsed 97.37s +epoch: 2, iter: 1400, avg_loss: 46.3206, kl/H(z|x): 0.1372, recon: 46.1834,time elapsed 98.49s +epoch: 2, iter: 1450, avg_loss: 47.1521, kl/H(z|x): 0.1887, recon: 46.9634,time elapsed 99.60s +kl weight 1.0000 +VAL --- avg_loss: 44.4193, kl/H(z|x): 0.0622, mi: 0.0397, recon: 44.3571, nll: 44.4193, ppl: 56.7194 +0 active units +update best loss +TEST --- avg_loss: 44.4166, kl/H(z|x): 0.0622, mi: -0.0181, recon: 44.3544, nll: 44.4166, ppl: 56.7056 +epoch: 3, iter: 1500, avg_loss: 45.3653, kl/H(z|x): 0.0619, recon: 45.3034,time elapsed 109.59s +epoch: 3, iter: 1550, avg_loss: 46.6098, kl/H(z|x): 0.0684, recon: 46.5414,time elapsed 110.68s +epoch: 3, iter: 1600, avg_loss: 46.6580, kl/H(z|x): 0.0872, recon: 46.5708,time elapsed 111.78s +epoch: 3, iter: 1650, avg_loss: 46.1828, kl/H(z|x): 0.0844, recon: 46.0984,time elapsed 113.02s +epoch: 3, iter: 1700, avg_loss: 46.3273, kl/H(z|x): 0.0641, recon: 46.2631,time elapsed 114.24s +epoch: 3, iter: 1750, avg_loss: 46.4725, kl/H(z|x): 0.0632, recon: 46.4093,time elapsed 115.38s +epoch: 3, iter: 1800, avg_loss: 46.6433, kl/H(z|x): 0.1000, recon: 46.5433,time elapsed 116.50s +epoch: 3, iter: 1850, avg_loss: 46.1040, kl/H(z|x): 0.0918, recon: 46.0121,time elapsed 117.60s +gamma Parameter containing: +tensor([0.7547, 0.8856, 1.1843, 1.0126, 1.0597, 0.9106, 0.9730, 1.1779, 0.7441, + 1.1986, 0.9274, 0.9449, 1.0565, 0.5328, 0.8193, 1.2728, 1.2570, 0.8743, + 0.9534, 0.7917, 0.8276, 0.8855, 1.0832, 1.0849, 1.1097, 1.1932, 0.7251, + 1.1920, 1.1917, 0.9867, 0.9388, 0.9593], requires_grad=True) +train loc mean tensor([ 0.0197, -0.0169, -0.0196, -0.0653, -0.0198, 0.0277, -0.0709, 0.0039, + 0.0091, -0.0372, -0.0042, -0.0662, 0.0050, 0.0271, 0.0054, 0.0648, + 0.0158, -0.1142, 0.0130, -0.0441, -0.0233, 0.0194, 0.0196, -0.0130, + 0.0155, 0.0078, 0.0079, 0.0202, 0.0312, -0.0052, 
-0.0383, -0.0199], + grad_fn=) +train scale std tensor([1.0685e-05, 8.4577e-06, 1.6077e-05, 1.5295e-05, 1.4061e-05, 1.4010e-05, + 2.2436e-05, 1.0087e-05, 5.1636e-06, 7.0330e-06, 5.6599e-06, 6.8874e-06, + 9.7419e-06, 1.8101e-05, 1.2978e-05, 1.5947e-05, 1.4863e-05, 5.4070e-06, + 3.5708e-06, 1.6960e-05, 1.7095e-05, 1.3345e-05, 1.7271e-05, 1.4926e-05, + 8.5931e-06, 1.1095e-05, 5.7805e-06, 1.9978e-05, 1.6887e-05, 5.9923e-06, + 8.9861e-06, 1.3255e-05], grad_fn=) +train scale mean tensor([1.0287, 1.0287, 1.0286, 1.0287, 1.0287, 1.0287, 1.0285, 1.0287, 1.0287, + 1.0287, 1.0288, 1.0288, 1.0287, 1.0286, 1.0287, 1.0286, 1.0287, 1.0288, + 1.0289, 1.0287, 1.0286, 1.0287, 1.0286, 1.0286, 1.0289, 1.0288, 1.0288, + 1.0285, 1.0286, 1.0288, 1.0288, 1.0285], grad_fn=) +epoch: 3, iter: 1900, avg_loss: 46.0540, kl/H(z|x): 0.0592, recon: 45.9948,time elapsed 118.72s +epoch: 3, iter: 1950, avg_loss: 45.6893, kl/H(z|x): 0.0609, recon: 45.6284,time elapsed 119.82s +kl weight 1.0000 +VAL --- avg_loss: 44.0410, kl/H(z|x): 0.0468, mi: 0.0664, recon: 43.9941, nll: 44.0410, ppl: 54.8018 +0 active units +update best loss +gamma Parameter containing: +tensor([0.7531, 0.8573, 1.1776, 1.0084, 1.0749, 0.9298, 0.9545, 1.1919, 0.7384, + 1.2042, 0.9418, 0.9320, 1.0016, 0.5333, 0.8417, 1.2837, 1.2379, 0.8844, + 0.9579, 0.7987, 0.8359, 0.8981, 1.0969, 1.0870, 1.1095, 1.2033, 0.7139, + 1.1977, 1.1889, 0.9906, 0.9035, 0.9766], requires_grad=True) +train loc mean tensor([ 0.0389, -0.0135, -0.0668, -0.0289, -0.0586, 0.0438, 0.0223, 0.0034, + 0.0052, 0.0265, 0.0018, 0.0417, -0.0163, 0.0010, 0.0236, -0.0516, + -0.0088, -0.0191, -0.0704, -0.0007, 0.0522, 0.0405, 0.0495, -0.0328, + -0.0161, -0.0143, 0.0084, -0.0512, 0.0116, 0.0380, -0.0292, -0.0483]) +train scale std tensor([8.3092e-06, 6.8119e-06, 1.7887e-05, 1.3828e-05, 1.8606e-05, 1.7448e-05, + 2.2577e-05, 3.3867e-06, 6.2293e-06, 4.3768e-06, 6.3535e-06, 7.3117e-06, + 4.1512e-06, 1.5505e-05, 8.7723e-06, 9.0478e-06, 2.0474e-05, 9.4471e-06, + 4.0762e-06, 
6.4814e-06, 1.3078e-05, 1.9357e-05, 6.2588e-06, 1.1728e-05, + 7.7367e-06, 9.2317e-06, 7.4100e-06, 1.7443e-05, 1.0280e-05, 2.9517e-06, + 1.0600e-05, 9.2339e-06]) +train scale mean tensor([1.0288, 1.0287, 1.0287, 1.0287, 1.0287, 1.0286, 1.0286, 1.0288, 1.0287, + 1.0288, 1.0287, 1.0288, 1.0287, 1.0287, 1.0288, 1.0287, 1.0286, 1.0288, + 1.0289, 1.0288, 1.0287, 1.0287, 1.0288, 1.0287, 1.0288, 1.0288, 1.0288, + 1.0286, 1.0288, 1.0288, 1.0289, 1.0287]) +TEST --- avg_loss: 44.0432, kl/H(z|x): 0.0468, mi: -0.0153, recon: 43.9964, nll: 44.0432, ppl: 54.8129 +epoch: 4, iter: 2000, avg_loss: 43.9364, kl/H(z|x): 0.0473, recon: 43.8891,time elapsed 129.21s +epoch: 4, iter: 2050, avg_loss: 46.1033, kl/H(z|x): 0.1925, recon: 45.9108,time elapsed 130.29s +epoch: 4, iter: 2100, avg_loss: 46.2411, kl/H(z|x): 0.0922, recon: 46.1489,time elapsed 131.37s +epoch: 4, iter: 2150, avg_loss: 46.2292, kl/H(z|x): 0.0933, recon: 46.1359,time elapsed 132.48s +epoch: 4, iter: 2200, avg_loss: 46.3031, kl/H(z|x): 0.1519, recon: 46.1512,time elapsed 133.61s +epoch: 4, iter: 2250, avg_loss: 46.1080, kl/H(z|x): 0.1128, recon: 45.9952,time elapsed 134.71s +epoch: 4, iter: 2300, avg_loss: 46.2069, kl/H(z|x): 0.1018, recon: 46.1051,time elapsed 135.80s +epoch: 4, iter: 2350, avg_loss: 46.0232, kl/H(z|x): 0.1626, recon: 45.8606,time elapsed 136.88s +epoch: 4, iter: 2400, avg_loss: 46.4018, kl/H(z|x): 0.1171, recon: 46.2848,time elapsed 138.01s +epoch: 4, iter: 2450, avg_loss: 45.8789, kl/H(z|x): 0.0603, recon: 45.8185,time elapsed 139.11s +kl weight 1.0000 +VAL --- avg_loss: 43.9069, kl/H(z|x): 0.0898, mi: 0.0050, recon: 43.8171, nll: 43.9069, ppl: 54.1382 +0 active units +update best loss +gamma Parameter containing: +tensor([0.7509, 0.8669, 1.1856, 0.9956, 1.0628, 0.9358, 0.9456, 1.1958, 0.7394, + 1.2110, 0.9154, 0.9326, 1.0014, 0.5382, 0.8347, 1.2811, 1.2412, 0.8911, + 0.9606, 0.8028, 0.8493, 0.8925, 1.1058, 1.0843, 1.1013, 1.2081, 0.7034, + 1.2092, 1.1926, 0.9834, 0.9101, 0.9705], requires_grad=True) 
+train loc mean tensor([-0.0450, -0.0759, -0.0783, 0.0150, 0.1056, 0.1129, -0.0218, 0.0801, + -0.0061, 0.0265, 0.1380, -0.0496, -0.0869, -0.0271, -0.0230, -0.0508, + 0.0639, -0.0268, 0.0608, 0.0407, -0.0362, -0.0448, -0.0678, 0.0024, + -0.0354, 0.0236, 0.0404, -0.1080, 0.0892, 0.0413, -0.0047, -0.0542]) +train scale std tensor([1.6924e-05, 1.6882e-05, 1.0246e-05, 1.2540e-05, 4.4682e-06, 1.0566e-05, + 1.2307e-05, 1.0101e-05, 1.4016e-05, 2.1958e-05, 2.5869e-05, 1.4119e-05, + 1.4662e-05, 1.0661e-05, 1.0339e-05, 9.7715e-06, 1.3193e-05, 1.2174e-05, + 5.5699e-06, 2.0073e-05, 1.2144e-05, 1.1599e-05, 1.1315e-05, 1.2591e-05, + 2.3388e-05, 1.0639e-05, 1.2606e-05, 1.6476e-05, 1.3287e-05, 1.2702e-05, + 1.0028e-05, 1.5114e-05]) +train scale mean tensor([1.0286, 1.0285, 1.0287, 1.0287, 1.0288, 1.0287, 1.0287, 1.0286, 1.0285, + 1.0286, 1.0285, 1.0286, 1.0286, 1.0288, 1.0286, 1.0288, 1.0285, 1.0287, + 1.0289, 1.0286, 1.0287, 1.0288, 1.0287, 1.0287, 1.0284, 1.0286, 1.0288, + 1.0286, 1.0285, 1.0287, 1.0286, 1.0285]) +TEST --- avg_loss: 43.9080, kl/H(z|x): 0.0898, mi: 0.0293, recon: 43.8182, nll: 43.9080, ppl: 54.1436 +epoch: 5, iter: 2500, avg_loss: 46.1478, kl/H(z|x): 0.0773, recon: 46.0705,time elapsed 148.53s +epoch: 5, iter: 2550, avg_loss: 45.7327, kl/H(z|x): 0.0603, recon: 45.6724,time elapsed 149.88s +epoch: 5, iter: 2600, avg_loss: 45.5843, kl/H(z|x): 0.0852, recon: 45.4991,time elapsed 151.01s +epoch: 5, iter: 2650, avg_loss: 45.9211, kl/H(z|x): 0.0895, recon: 45.8317,time elapsed 152.14s +epoch: 5, iter: 2700, avg_loss: 45.5345, kl/H(z|x): 0.0683, recon: 45.4662,time elapsed 153.30s diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt new file mode 100644 index 0000000..3fe6613 Binary files /dev/null and b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt differ diff --git 
a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt new file mode 100644 index 0000000..20aafca --- /dev/null +++ b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt @@ -0,0 +1,77 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.2, reset_dec=False, save_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 98.1471, kl/H(z|x): 22.1174, mi: 0.0398, recon: 76.0297,au 0, time elapsed 3.80s +epoch: 0, iter: 50, avg_loss: 75.2584, kl/H(z|x): 8.0159, mi: -0.0773, recon: 67.2426,au 0, time elapsed 8.56s +epoch: 0, iter: 100, avg_loss: 67.1634, kl/H(z|x): 6.2421, mi: -0.0161, recon: 60.9213,au 0, time elapsed 13.36s +epoch: 0, iter: 150, avg_loss: 65.9610, kl/H(z|x): 6.4145, mi: 0.6229, recon: 59.5465,au 26, time elapsed 20.04s +epoch: 0, iter: 200, avg_loss: 64.0590, kl/H(z|x): 6.3462, mi: 0.0654, recon: 57.7129,au 7, time elapsed 25.26s +epoch: 0, iter: 250, avg_loss: 60.5662, kl/H(z|x): 
6.2482, mi: -0.0251, recon: 54.3180,au 0, time elapsed 30.56s +epoch: 0, iter: 300, avg_loss: 58.9862, kl/H(z|x): 6.2296, mi: -0.0323, recon: 52.7566,au 0, time elapsed 35.99s +epoch: 0, iter: 350, avg_loss: 57.9861, kl/H(z|x): 6.2317, mi: 0.0645, recon: 51.7544,au 0, time elapsed 41.22s +epoch: 0, iter: 400, avg_loss: 57.7159, kl/H(z|x): 6.2682, mi: -0.0005, recon: 51.4477,au 0, time elapsed 46.25s +epoch: 0, iter: 450, avg_loss: 56.8888, kl/H(z|x): 6.1970, mi: 0.0484, recon: 50.6919,au 0, time elapsed 53.38s +kl weight 1.0000 +VAL --- avg_loss: 69.2053, kl/H(z|x): 18.6423, mi: -0.1538, recon: 50.5630, nll: 69.2053, ppl: 539.9051 +0 active units +update best loss +TEST --- avg_loss: 69.2140, kl/H(z|x): 18.6423, mi: -0.0284, recon: 50.5717, nll: 69.2140, ppl: 540.3313 +epoch: 1, iter: 500, avg_loss: 55.1085, kl/H(z|x): 6.3617, recon: 48.7468,time elapsed 62.84s +epoch: 1, iter: 550, avg_loss: 55.3878, kl/H(z|x): 6.1804, recon: 49.2074,time elapsed 64.00s +epoch: 1, iter: 600, avg_loss: 55.1184, kl/H(z|x): 6.2019, recon: 48.9165,time elapsed 65.20s +epoch: 1, iter: 650, avg_loss: 54.3091, kl/H(z|x): 6.2124, recon: 48.0967,time elapsed 66.34s +epoch: 1, iter: 700, avg_loss: 54.2357, kl/H(z|x): 6.5016, recon: 47.7341,time elapsed 67.88s +epoch: 1, iter: 750, avg_loss: 53.9039, kl/H(z|x): 6.1897, recon: 47.7142,time elapsed 69.46s +epoch: 1, iter: 800, avg_loss: 53.8167, kl/H(z|x): 6.2032, recon: 47.6135,time elapsed 70.90s +epoch: 1, iter: 850, avg_loss: 53.5023, kl/H(z|x): 6.2023, recon: 47.2999,time elapsed 73.23s +epoch: 1, iter: 900, avg_loss: 53.8856, kl/H(z|x): 6.2500, recon: 47.6356,time elapsed 75.30s +epoch: 1, iter: 950, avg_loss: 53.3251, kl/H(z|x): 6.2294, recon: 47.0958,time elapsed 76.53s +kl weight 1.0000 +VAL --- avg_loss: 49.8829, kl/H(z|x): 3.8641, mi: 0.0947, recon: 46.0188, nll: 49.8829, ppl: 93.2058 +0 active units +update best loss +TEST --- avg_loss: 49.8833, kl/H(z|x): 3.8641, mi: -0.0544, recon: 46.0192, nll: 49.8833, ppl: 93.2088 +epoch: 2, 
iter: 1000, avg_loss: 53.2816, kl/H(z|x): 5.3442, recon: 47.9374,time elapsed 91.88s +epoch: 2, iter: 1050, avg_loss: 52.7176, kl/H(z|x): 6.2500, recon: 46.4677,time elapsed 93.11s +epoch: 2, iter: 1100, avg_loss: 53.0183, kl/H(z|x): 6.2678, recon: 46.7505,time elapsed 94.73s +epoch: 2, iter: 1150, avg_loss: 53.2235, kl/H(z|x): 6.2141, recon: 47.0094,time elapsed 95.95s +epoch: 2, iter: 1200, avg_loss: 52.6633, kl/H(z|x): 6.0780, recon: 46.5854,time elapsed 99.24s +epoch: 2, iter: 1250, avg_loss: 52.9677, kl/H(z|x): 6.1585, recon: 46.8092,time elapsed 100.49s +epoch: 2, iter: 1300, avg_loss: 52.5459, kl/H(z|x): 6.1299, recon: 46.4160,time elapsed 101.89s +epoch: 2, iter: 1350, avg_loss: 52.6617, kl/H(z|x): 6.0985, recon: 46.5632,time elapsed 103.14s +epoch: 2, iter: 1400, avg_loss: 52.0059, kl/H(z|x): 6.0843, recon: 45.9216,time elapsed 104.36s +epoch: 2, iter: 1450, avg_loss: 52.3689, kl/H(z|x): 6.1425, recon: 46.2264,time elapsed 105.60s +kl weight 1.0000 +VAL --- avg_loss: 45.3509, kl/H(z|x): 1.3035, mi: 0.0306, recon: 44.0474, nll: 45.3509, ppl: 61.7324 +0 active units +update best loss +TEST --- avg_loss: 45.3404, kl/H(z|x): 1.3035, mi: -0.0230, recon: 44.0370, nll: 45.3404, ppl: 61.6738 +epoch: 3, iter: 1500, avg_loss: 49.9944, kl/H(z|x): 5.5341, recon: 44.4602,time elapsed 117.11s +epoch: 3, iter: 1550, avg_loss: 52.2121, kl/H(z|x): 6.0559, recon: 46.1563,time elapsed 119.99s +epoch: 3, iter: 1600, avg_loss: 52.0304, kl/H(z|x): 6.0492, recon: 45.9812,time elapsed 121.21s +epoch: 3, iter: 1650, avg_loss: 52.0216, kl/H(z|x): 6.1220, recon: 45.8996,time elapsed 122.39s +epoch: 3, iter: 1700, avg_loss: 52.1720, kl/H(z|x): 6.1192, recon: 46.0528,time elapsed 123.50s +epoch: 3, iter: 1750, avg_loss: 52.4309, kl/H(z|x): 6.0886, recon: 46.3423,time elapsed 124.60s +epoch: 3, iter: 1800, avg_loss: 52.6185, kl/H(z|x): 6.3136, recon: 46.3049,time elapsed 125.72s +epoch: 3, iter: 1850, avg_loss: 51.9571, kl/H(z|x): 6.1577, recon: 45.7994,time elapsed 126.91s +epoch: 3, 
iter: 1900, avg_loss: 51.9346, kl/H(z|x): 6.1612, recon: 45.7734,time elapsed 128.86s +epoch: 3, iter: 1950, avg_loss: 51.6526, kl/H(z|x): 6.1017, recon: 45.5510,time elapsed 131.37s +kl weight 1.0000 +VAL --- avg_loss: 44.7754, kl/H(z|x): 0.8472, mi: 0.0321, recon: 43.9282, nll: 44.7754, ppl: 58.5859 +0 active units +update best loss +TEST --- avg_loss: 44.7674, kl/H(z|x): 0.8472, mi: -0.0025, recon: 43.9202, nll: 44.7674, ppl: 58.5432 +epoch: 4, iter: 2000, avg_loss: 50.4556, kl/H(z|x): 6.1846, recon: 44.2711,time elapsed 142.19s +epoch: 4, iter: 2050, avg_loss: 51.7029, kl/H(z|x): 6.2001, recon: 45.5028,time elapsed 143.48s +epoch: 4, iter: 2100, avg_loss: 52.0093, kl/H(z|x): 6.0755, recon: 45.9338,time elapsed 144.65s +epoch: 4, iter: 2150, avg_loss: 51.8900, kl/H(z|x): 6.0952, recon: 45.7948,time elapsed 145.75s +epoch: 4, iter: 2200, avg_loss: 51.6964, kl/H(z|x): 6.0536, recon: 45.6428,time elapsed 146.84s +epoch: 4, iter: 2250, avg_loss: 51.6489, kl/H(z|x): 6.1424, recon: 45.5065,time elapsed 148.00s +epoch: 4, iter: 2300, avg_loss: 51.8192, kl/H(z|x): 6.0458, recon: 45.7734,time elapsed 149.10s +epoch: 4, iter: 2350, avg_loss: 51.3648, kl/H(z|x): 5.9993, recon: 45.3655,time elapsed 150.19s +epoch: 4, iter: 2400, avg_loss: 52.3250, kl/H(z|x): 6.2564, recon: 46.0686,time elapsed 151.32s +epoch: 4, iter: 2450, avg_loss: 51.6578, kl/H(z|x): 6.0313, recon: 45.6266,time elapsed 152.43s +kl weight 1.0000 +VAL --- avg_loss: 44.6217, kl/H(z|x): 0.7541, mi: 0.1146, recon: 43.8676, nll: 44.6217, ppl: 57.7731 diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt new file mode 100644 index 0000000..6abdeed Binary files /dev/null and b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt differ diff --git 
a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt new file mode 100644 index 0000000..af8a5a9 --- /dev/null +++ b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt @@ -0,0 +1,10 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 109.8254, kl/H(z|x): 33.7955, mi: 0.0425, recon: 76.0299,au 0, time elapsed 3.76s +epoch: 0, iter: 50, avg_loss: 83.8892, kl/H(z|x): 16.9853, mi: -0.0774, recon: 66.9040,au 0, time elapsed 8.47s +epoch: 0, iter: 100, avg_loss: 75.9278, kl/H(z|x): 15.2441, mi: -0.0223, recon: 60.6837,au 0, time elapsed 13.28s +epoch: 0, iter: 150, avg_loss: 74.6431, kl/H(z|x): 15.2945, mi: 0.5371, recon: 59.3485,au 20, time elapsed 17.94s +epoch: 0, iter: 200, avg_loss: 72.5163, kl/H(z|x): 15.3962, mi: -0.0651, recon: 57.1201,au 0, time elapsed 22.74s diff --git 
a/models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt new file mode 100644 index 0000000..a5948a4 --- /dev/null +++ b/models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt @@ -0,0 +1,9 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 89.9236, kl/H(z|x): 13.8949, mi: 0.0091, recon: 76.0287,au 0, time elapsed 3.66s +epoch: 0, iter: 50, avg_loss: 75.8018, kl/H(z|x): 9.7718, mi: 0.0997, recon: 66.0301,au 0, time elapsed 8.55s +epoch: 0, iter: 100, avg_loss: 69.6457, kl/H(z|x): 9.1719, mi: -0.0288, recon: 60.4737,au 0, time elapsed 13.30s +epoch: 0, iter: 150, avg_loss: 68.5756, kl/H(z|x): 9.0761, mi: -0.0296, recon: 59.4994,au 0, time elapsed 18.10s diff --git a/models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt 
b/models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt new file mode 100644 index 0000000..4f8fdca --- /dev/null +++ b/models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt @@ -0,0 +1,9 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.15, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 89.9236, kl/H(z|x): 13.8949, mi: 0.0091, recon: 76.0287,au 0, time elapsed 5.64s +epoch: 0, iter: 50, avg_loss: 76.0595, kl/H(z|x): 9.8928, mi: 0.0997, recon: 66.1668,au 0, time elapsed 10.54s +epoch: 0, iter: 100, avg_loss: 69.7483, kl/H(z|x): 9.2888, mi: -0.0287, recon: 60.4595,au 0, time elapsed 15.83s +epoch: 0, iter: 150, avg_loss: 68.6088, kl/H(z|x): 9.1824, mi: -0.0296, recon: 59.4264,au 0, time elapsed 20.88s diff --git a/models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt new file mode 100644 index 0000000..6fb79c7 --- /dev/null +++ 
b/models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt @@ -0,0 +1,11 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=0.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.0, gamma_train=False, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, save_path='models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 84.0465, kl/H(z|x): 8.0181, mi: -0.0320, recon: 76.0285,au 0, time elapsed 3.85s +epoch: 0, iter: 50, avg_loss: 67.9761, kl/H(z|x): 1.1939, mi: 0.0944, recon: 66.7822,au 0, time elapsed 8.59s +epoch: 0, iter: 100, avg_loss: 60.7799, kl/H(z|x): 0.1675, mi: -0.0444, recon: 60.6124,au 0, time elapsed 13.57s +epoch: 0, iter: 150, avg_loss: 59.6752, kl/H(z|x): 0.1535, mi: -0.0352, recon: 59.5216,au 0, time elapsed 18.29s +epoch: 0, iter: 200, avg_loss: 56.9132, kl/H(z|x): 0.1076, mi: 0.0690, recon: 56.8057,au 0, time elapsed 23.04s +epoch: 0, iter: 250, avg_loss: 54.5933, kl/H(z|x): 0.1175, mi: 0.0078, recon: 54.4758,au 0, time elapsed 27.81s diff --git a/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt new file 
mode 100644 index 0000000..011c7b5 --- /dev/null +++ b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt @@ -0,0 +1,7 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.8, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 90.0487, kl/H(z|x): 14.0200, mi: -0.0322, recon: 76.0287,au 0, time elapsed 3.93s +epoch: 0, iter: 50, avg_loss: 81.0231, kl/H(z|x): 14.3386, mi: 0.7551, recon: 66.6845,au 12, time elapsed 8.63s diff --git a/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt new file mode 100644 index 0000000..4d782c9 Binary files /dev/null and b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt differ diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt 
b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt new file mode 100644 index 0000000..4ed17a2 --- /dev/null +++ b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt @@ -0,0 +1,24 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=0.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.8, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 84.2557, kl/H(z|x): 8.2271, mi: -0.0313, recon: 76.0286,au 0, time elapsed 3.82s +epoch: 0, iter: 50, avg_loss: 70.6074, kl/H(z|x): 4.2926, mi: 0.7366, recon: 66.3148,au 22, time elapsed 9.52s +epoch: 0, iter: 100, avg_loss: 62.1741, kl/H(z|x): 2.6194, mi: 0.6129, recon: 59.5546,au 23, time elapsed 14.88s +epoch: 0, iter: 150, avg_loss: 60.9270, kl/H(z|x): 2.6617, mi: 0.6679, recon: 58.2653,au 23, time elapsed 20.14s +epoch: 0, iter: 200, avg_loss: 58.7194, kl/H(z|x): 2.2669, mi: 0.6364, recon: 56.4524,au 24, time elapsed 25.00s +epoch: 0, iter: 250, avg_loss: 55.7990, kl/H(z|x): 1.8970, mi: 0.4772, recon: 53.9020,au 20, time elapsed 29.83s +epoch: 0, iter: 300, avg_loss: 
54.3986, kl/H(z|x): 1.9483, mi: 0.3116, recon: 52.4503,au 19, time elapsed 34.72s +epoch: 0, iter: 350, avg_loss: 53.3452, kl/H(z|x): 1.8530, mi: 0.5277, recon: 51.4922,au 19, time elapsed 39.95s +epoch: 0, iter: 400, avg_loss: 53.2150, kl/H(z|x): 1.7295, mi: 0.2439, recon: 51.4855,au 16, time elapsed 45.09s +epoch: 0, iter: 450, avg_loss: 52.9093, kl/H(z|x): 1.8259, mi: 0.3234, recon: 51.0834,au 18, time elapsed 50.17s +kl weight 1.0000 +VAL --- avg_loss: 48.8077, kl/H(z|x): 1.2442, mi: 0.2892, recon: 47.5353, nll: 48.7795, ppl: 84.3100 +20 active units +update best loss +TEST --- avg_loss: 48.8295, kl/H(z|x): 1.2387, mi: 0.2400, recon: 47.5826, nll: 48.8212, ppl: 84.6304 +epoch: 1, iter: 500, avg_loss: 47.8222, kl/H(z|x): 1.2917, recon: 46.5305,time elapsed 60.22s +epoch: 1, iter: 550, avg_loss: 51.4242, kl/H(z|x): 1.7375, recon: 49.6867,time elapsed 61.58s +epoch: 1, iter: 600, avg_loss: 51.8762, kl/H(z|x): 1.6511, recon: 50.2251,time elapsed 64.43s +epoch: 1, iter: 650, avg_loss: 50.9553, kl/H(z|x): 1.6560, recon: 49.2993,time elapsed 66.71s diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt new file mode 100644 index 0000000..a1cabc4 Binary files /dev/null and b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt differ diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt new file mode 100644 index 0000000..a874cc7 --- /dev/null +++ b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt @@ -0,0 +1,28 @@ +Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, 
enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.8, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100) +data/synthetic_data/vocab.txt +Train data: 16000 samples +finish reading datasets, vocab size is 1004 +dropped sentences: 0 +epoch: 0, iter: 0, avg_loss: 90.0487, kl/H(z|x): 14.0200, mi: 0.0092, recon: 76.0288,au 0, time elapsed 3.77s +epoch: 0, iter: 50, avg_loss: 80.6753, kl/H(z|x): 14.6578, mi: 0.8141, recon: 66.0176,au 18, time elapsed 8.55s +epoch: 0, iter: 100, avg_loss: 72.3764, kl/H(z|x): 13.0901, mi: 1.0537, recon: 59.2863,au 19, time elapsed 13.41s +epoch: 0, iter: 150, avg_loss: 71.4821, kl/H(z|x): 13.4491, mi: 0.6681, recon: 58.0330,au 19, time elapsed 18.23s +epoch: 0, iter: 200, avg_loss: 68.6188, kl/H(z|x): 12.5864, mi: 0.6411, recon: 56.0324,au 18, time elapsed 23.76s +epoch: 0, iter: 250, avg_loss: 65.7556, kl/H(z|x): 12.0306, mi: 0.8201, recon: 53.7249,au 17, time elapsed 29.69s +epoch: 0, iter: 300, avg_loss: 64.6366, kl/H(z|x): 12.2975, mi: 0.6563, recon: 52.3392,au 21, time elapsed 34.73s +epoch: 0, iter: 350, avg_loss: 63.6333, kl/H(z|x): 12.1879, mi: 0.8044, recon: 51.4454,au 21, time elapsed 39.77s +epoch: 0, iter: 400, avg_loss: 63.1732, kl/H(z|x): 12.0782, mi: 0.5898, recon: 51.0949,au 15, time elapsed 44.60s +epoch: 0, iter: 450, avg_loss: 62.4059, kl/H(z|x): 11.7999, mi: 0.6608, recon: 50.6061,au 15, time elapsed 50.74s +kl weight 
1.0000 +VAL --- avg_loss: 51.0809, kl/H(z|x): 2.9957, mi: 0.6483, recon: 48.0950, nll: 51.0907, ppl: 104.0226 +22 active units +update best loss +TEST --- avg_loss: 51.0537, kl/H(z|x): 2.9736, mi: 0.6390, recon: 48.0735, nll: 51.0471, ppl: 103.6111 +epoch: 1, iter: 500, avg_loss: 59.4675, kl/H(z|x): 11.8408, recon: 47.6266,time elapsed 61.07s +epoch: 1, iter: 550, avg_loss: 60.6828, kl/H(z|x): 11.7920, recon: 48.8908,time elapsed 62.34s +epoch: 1, iter: 600, avg_loss: 60.5349, kl/H(z|x): 11.8571, recon: 48.6779,time elapsed 63.65s +epoch: 1, iter: 650, avg_loss: 59.8986, kl/H(z|x): 11.9386, recon: 47.9600,time elapsed 64.92s +epoch: 1, iter: 700, avg_loss: 59.4865, kl/H(z|x): 11.8225, recon: 47.6640,time elapsed 66.25s +epoch: 1, iter: 750, avg_loss: 59.0838, kl/H(z|x): 11.7320, recon: 47.3517,time elapsed 67.57s +epoch: 1, iter: 800, avg_loss: 58.9415, kl/H(z|x): 11.8043, recon: 47.1373,time elapsed 68.85s +epoch: 1, iter: 850, avg_loss: 58.6082, kl/H(z|x): 11.6145, recon: 46.9938,time elapsed 70.12s diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt new file mode 100644 index 0000000..c7c69b7 Binary files /dev/null and b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt differ diff --git a/modules/.DS_Store b/modules/.DS_Store new file mode 100644 index 0000000..78526d3 Binary files /dev/null and b/modules/.DS_Store differ diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100755 index 0000000..62f9459 --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,6 @@ +from .encoders import * +from .decoders import * +from .vae import * +from .vae_IAF import * +from .utils import * +from .discriminators import * \ No newline at end of file diff --git a/modules/decoders/__init__.py b/modules/decoders/__init__.py new file mode 100755 index 0000000..eb2c937 --- 
/dev/null +++ b/modules/decoders/__init__.py @@ -0,0 +1,3 @@ +from .dec_lstm import * +from .dec_pixelcnn import * +from .dec_pixelcnn_v2 import * \ No newline at end of file diff --git a/modules/decoders/dec_lstm.py b/modules/decoders/dec_lstm.py new file mode 100755 index 0000000..c9fc172 --- /dev/null +++ b/modules/decoders/dec_lstm.py @@ -0,0 +1,477 @@ +# import torch + +import time +import argparse + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence + +import numpy as np + +from .decoder import DecoderBase +from .decoder_helper import BeamSearchNode + +class LSTMDecoder(DecoderBase): + """LSTM decoder with constant-length batching""" + def __init__(self, args, vocab, model_init, emb_init): + super(LSTMDecoder, self).__init__() + self.ni = args.ni + self.nh = args.dec_nh + self.nz = args.nz + self.vocab = vocab + self.device = args.device + + # no padding when setting padding_idx to -1 + self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1) + + self.dropout_in = nn.Dropout(args.dec_dropout_in) + self.dropout_out = nn.Dropout(args.dec_dropout_out) + + # for initializing hidden state and cell + self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False) + + # concatenate z with input + self.lstm = nn.LSTM(input_size=args.ni + args.nz, + hidden_size=args.dec_nh, + num_layers=1, + batch_first=True) + + # prediction layer + self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False) + + vocab_mask = torch.ones(len(vocab)) + # vocab_mask[vocab['']] = 0 + self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False) + + self.reset_parameters(model_init, emb_init) + + def reset_parameters(self, model_init, emb_init): + # for name, param in self.lstm.named_parameters(): + # # self.initializer(param) + # if 'bias' in name: + # nn.init.constant_(param, 0.0) + # # model_init(param) + # elif 'weight' in name: + # model_init(param) + + # 
model_init(self.trans_linear.weight) + # model_init(self.pred_linear.weight) + for param in self.parameters(): + model_init(param) + emb_init(self.embed.weight) + + def decode(self, input, z): + """ + Args: + input: (batch_size, seq_len) + z: (batch_size, n_sample, nz) + """ + + # not predicting start symbol + # sents_len -= 1 + + batch_size, n_sample, _ = z.size() + seq_len = input.size(1) + + # (batch_size, seq_len, ni) + word_embed = self.embed(input) + word_embed = self.dropout_in(word_embed) + + if n_sample == 1: + z_ = z.expand(batch_size, seq_len, self.nz) + + else: + word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \ + .contiguous() + + # (batch_size * n_sample, seq_len, ni) + word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni) + + z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous() + z_ = z_.view(batch_size * n_sample, seq_len, self.nz) + + # (batch_size * n_sample, seq_len, ni + nz) + word_embed = torch.cat((word_embed, z_), -1) + + z = z.view(batch_size * n_sample, self.nz) + c_init = self.trans_linear(z).unsqueeze(0) + h_init = torch.tanh(c_init) + # h_init = self.trans_linear(z).unsqueeze(0) + # c_init = h_init.new_zeros(h_init.size()) + output, _ = self.lstm(word_embed, (h_init, c_init)) + + output = self.dropout_out(output) + + # (batch_size * n_sample, seq_len, vocab_size) + output_logits = self.pred_linear(output) + + return output_logits + + def reconstruct_error(self, x, z): + """Cross Entropy in the language case + Args: + x: (batch_size, seq_len) + z: (batch_size, n_sample, nz) + Returns: + loss: (batch_size, n_sample). 
Loss + across different sentence and z + """ + + #remove end symbol + src = x[:, :-1] + + # remove start symbol + tgt = x[:, 1:] + + batch_size, seq_len = src.size() + n_sample = z.size(1) + + # (batch_size * n_sample, seq_len, vocab_size) + output_logits = self.decode(src, z) + + if n_sample == 1: + tgt = tgt.contiguous().view(-1) + else: + # (batch_size * n_sample * seq_len) + tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \ + .contiguous().view(-1) + + # (batch_size * n_sample * seq_len) + loss = self.loss(output_logits.view(-1, output_logits.size(2)), + tgt) + + + # (batch_size, n_sample) + return loss.view(batch_size, n_sample, -1).sum(-1) + + + def log_probability(self, x, z): + """Cross Entropy in the language case + Args: + x: (batch_size, seq_len) + z: (batch_size, n_sample, nz) + Returns: + log_p: (batch_size, n_sample). + log_p(x|z) across different x and z + """ + + return -self.reconstruct_error(x, z) + + def beam_search_decode(self, z, K=5): + """beam search decoding, code is based on + https://github.com/pcyin/pytorch_basic_nmt/blob/master/nmt.py + + the current implementation decodes sentence one by one, further batching would improve the speed + + Args: + z: (batch_size, nz) + K: the beam width + + Returns: List1 + List1: the decoded word sentence list + """ + + decoded_batch = [] + batch_size, nz = z.size() + + # (1, batch_size, nz) + c_init = self.trans_linear(z).unsqueeze(0) + h_init = torch.tanh(c_init) + + # decoding goes sentence by sentence + for idx in range(batch_size): + # Start with the start of the sentence token + decoder_input = torch.tensor([[self.vocab[""]]], dtype=torch.long, device=self.device) + decoder_hidden = (h_init[:,idx,:].unsqueeze(1), c_init[:,idx,:].unsqueeze(1)) + + node = BeamSearchNode(decoder_hidden, None, decoder_input, 0., 1) + live_hypotheses = [node] + + completed_hypotheses = [] + + t = 0 + while len(completed_hypotheses) < K and t < 100: + t += 1 + + # (len(live), 1) + decoder_input = 
torch.cat([node.wordid for node in live_hypotheses], dim=0) + + # (1, len(live), nh) + decoder_hidden_h = torch.cat([node.h[0] for node in live_hypotheses], dim=1) + decoder_hidden_c = torch.cat([node.h[1] for node in live_hypotheses], dim=1) + + decoder_hidden = (decoder_hidden_h, decoder_hidden_c) + + + # (len(live), 1, ni) --> (len(live), 1, ni+nz) + word_embed = self.embed(decoder_input) + word_embed = torch.cat((word_embed, z[idx].view(1, 1, -1).expand( + len(live_hypotheses), 1, nz)), dim=-1) + + output, decoder_hidden = self.lstm(word_embed, decoder_hidden) + + # (len(live), 1, vocab_size) + output_logits = self.pred_linear(output) + decoder_output = F.log_softmax(output_logits, dim=-1) + + prev_logp = torch.tensor([node.logp for node in live_hypotheses], dtype=torch.float, device=self.device) + decoder_output = decoder_output + prev_logp.view(len(live_hypotheses), 1, 1) + + # (len(live) * vocab_size) + decoder_output = decoder_output.view(-1) + + # (K) + log_prob, indexes = torch.topk(decoder_output, K-len(completed_hypotheses)) + + live_ids = indexes // len(self.vocab) + word_ids = indexes % len(self.vocab) + + live_hypotheses_new = [] + for live_id, word_id, log_prob_ in zip(live_ids, word_ids, log_prob): + node = BeamSearchNode((decoder_hidden[0][:, live_id, :].unsqueeze(1), + decoder_hidden[1][:, live_id, :].unsqueeze(1)), + live_hypotheses[live_id], word_id.view(1, 1), log_prob_, t) + + if word_id.item() == self.vocab[""]: + completed_hypotheses.append(node) + else: + live_hypotheses_new.append(node) + + live_hypotheses = live_hypotheses_new + + if len(completed_hypotheses) == K: + break + + for live in live_hypotheses: + completed_hypotheses.append(live) + + utterances = [] + for n in sorted(completed_hypotheses, key=lambda node: node.logp, reverse=True): + utterance = [] + utterance.append(self.vocab.id2word(n.wordid.item())) + # back trace + while n.prevNode != None: + n = n.prevNode + utterance.append(self.vocab.id2word(n.wordid.item())) + + 
utterance = utterance[::-1] + + utterances.append(utterance) + + # only save the top 1 + break + + decoded_batch.append(utterances[0]) + + return decoded_batch + + def greedy_decode(self, z): + """greedy decoding from z + Args: + z: (batch_size, nz) + + Returns: List1 + List1: the decoded word sentence list + """ + + batch_size = z.size(0) + decoded_batch = [[] for _ in range(batch_size)] + + # (batch_size, 1, nz) + c_init = self.trans_linear(z).unsqueeze(0) + h_init = torch.tanh(c_init) + + decoder_hidden = (h_init, c_init) + decoder_input = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device).unsqueeze(1) + end_symbol = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device) + + mask = torch.ones((batch_size), dtype=torch.uint8, device=self.device) + length_c = 1 + while mask.sum().item() != 0 and length_c < 100: + + # (batch_size, 1, ni) --> (batch_size, 1, ni+nz) + word_embed = self.embed(decoder_input) + word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1) + + output, decoder_hidden = self.lstm(word_embed, decoder_hidden) + + # (batch_size, 1, vocab_size) --> (batch_size, vocab_size) + decoder_output = self.pred_linear(output) + output_logits = decoder_output.squeeze(1) + + # (batch_size) + max_index = torch.argmax(output_logits, dim=1) + # max_index = torch.multinomial(probs, num_samples=1) + + decoder_input = max_index.unsqueeze(1) + length_c += 1 + + for i in range(batch_size): + if mask[i].item(): + decoded_batch[i].append(self.vocab.id2word(max_index[i].item())) + + mask = torch.mul((max_index != end_symbol), mask) + + return decoded_batch + + def sample_decode(self, z): + """sampling decoding from z + Args: + z: (batch_size, nz) + + Returns: List1 + List1: the decoded word sentence list + """ + + batch_size = z.size(0) + decoded_batch = [[] for _ in range(batch_size)] + + # (batch_size, 1, nz) + c_init = self.trans_linear(z).unsqueeze(0) + h_init = torch.tanh(c_init) + + decoder_hidden = 
class VarLSTMDecoder(LSTMDecoder):
    """LSTM decoder with variable-length batching (packed sequences)."""

    def __init__(self, args, vocab, model_init, emb_init):
        super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)

        # NOTE(review): the pad-token key was garbled ('') in the source;
        # restored to '<pad>' -- confirm against the project's vocabulary.
        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['<pad>'])
        vocab_mask = torch.ones(len(vocab))
        # zero weight: no loss contribution from padding positions
        vocab_mask[vocab['<pad>']] = 0
        # reduction='none' replaces the deprecated reduce=False; both produce
        # an element-wise (unreduced) loss
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        self.reset_parameters(model_init, emb_init)

    def decode(self, input, z):
        """Teacher-forced decoding over packed variable-length batches.

        Args:
            input: tuple (x, sents_len)
                x: (batch_size, seq_len) token ids
                sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz) latent codes

        Returns:
            (batch_size * n_sample, seq_len, vocab_size) output logits
        """
        input, sents_len = input

        # not predicting start symbol
        sents_len = sents_len - 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)
        else:
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                .contiguous()

            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)

            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        # NOTE(review): pack_padded_sequence assumes length-sorted batches by
        # default -- confirm the data pipeline sorts sentences by length.
        sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
        packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)

        z = z.view(batch_size * n_sample, self.nz)
        # initialize the LSTM state from z: c = W z, h = tanh(c)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        output, _ = self.lstm(packed_embed, (h_init, c_init))
        output, _ = pad_packed_sequence(output, batch_first=True)

        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)

        return output_logits

    def reconstruct_error(self, x, z):
        """Cross-entropy reconstruction loss.

        Args:
            x: tuple (x_, sents_len)
                x_: (batch_size, seq_len) token ids
                sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)

        Returns:
            loss: (batch_size, n_sample), summed over time steps
        """
        x, sents_len = x

        # remove end symbol
        src = x[:, :-1]

        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode((src, sents_len), z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)), tgt)

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)
class StackedGatedMaskedConv2d(nn.Module):
    """Stack of gated masked convolutions, optionally conditioned on a latent z."""

    def __init__(self,
                 img_size=[1, 28, 28], layers=[64, 64, 64],
                 kernel_size=[7, 7, 7], latent_dim=64, latent_feature_map=1):
        super(StackedGatedMaskedConv2d, self).__init__()
        input_dim = img_size[0]
        self.conv_layers = []
        if latent_feature_map > 0:
            self.latent_feature_map = latent_feature_map
            # NOTE(review): spatial size 28x28 is hard-coded here although
            # img_size is a parameter -- confirm img_size[1:] == (28, 28).
            self.z_linear = nn.Linear(latent_dim, latent_feature_map * 28 * 28)
        for i in range(len(kernel_size)):
            if i == 0:
                # first layer is mask 'A': no access to the current pixel
                self.conv_layers.append(GatedMaskedConv2d(input_dim + latent_feature_map,
                                                          layers[i], kernel_size[i], 'A'))
            else:
                self.conv_layers.append(GatedMaskedConv2d(layers[i - 1], layers[i], kernel_size[i]))

        # Register the layers. Bug fix: the original assigned to
        # `self.modules`, shadowing nn.Module.modules() and breaking any code
        # that iterates sub-modules (e.g. weight init, .to(device) traversal).
        self.conv_module_list = nn.ModuleList(self.conv_layers)

    def forward(self, img, q_z=None):
        """
        Args:
            img: (batch, nc, H, W)
            q_z: (batch, nsamples, nz) or None for the unconditional case

        Returns:
            (batch * nsamples, dim, H, W) horizontal-stack feature map
        """
        # Bug fix: q_z.size() was unpacked BEFORE the None check, so the
        # q_z=None path could never be taken; guard first.
        if q_z is not None:
            batch_size, nsamples, _ = q_z.size()
            z_img = self.z_linear(q_z)
            z_img = z_img.view(img.size(0), nsamples, self.latent_feature_map, img.size(2), img.size(3))

            # (batch, nsamples, nc, H, W)
            img = img.unsqueeze(1).expand(batch_size, nsamples, *img.size()[1:])

        for i in range(len(self.conv_layers)):
            if i == 0:
                if q_z is not None:
                    # (batch, nsamples, nc + fm, H, W) --> (batch * nsamples, nc + fm, H, W)
                    v_map = torch.cat([img, z_img], 2)
                    v_map = v_map.view(-1, *v_map.size()[2:])
                else:
                    v_map = img
                h_map = v_map
            v_map, h_map = self.conv_layers[i](v_map, h_map)
        return h_map
"""docstring for PixelCNNDecoder""" + def __init__(self, args): + super(PixelCNNDecoder, self).__init__() + self.dec_cnn = StackedGatedMaskedConv2d(img_size=args.img_size, layers = args.dec_layers, + latent_dim= args.nz, kernel_size = args.dec_kernel_size, + latent_feature_map = args.latent_feature_map) + + self.dec_linear = nn.Conv2d(args.dec_layers[-1], args.img_size[0], kernel_size = 1) + + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + he_init(m) + + def decode(self, img, q_z): + dec_cnn_output = self.dec_cnn(img, q_z) + pred = F.sigmoid(self.dec_linear(dec_cnn_output)) + return pred + + def reconstruct_error(self, x, z): + """Cross Entropy in the language case + Args: + x: (batch_size, nc, H, W) + z: (batch_size, n_sample, nz) + Returns: + loss: (batch_size, n_sample). Loss + across different sentence and z + """ + + batch_size, nsamples, _ = z.size() + # (batch * nsamples, nc, H, W) + pred = self.decode(x, z) + prob = torch.clamp(pred.view(pred.size(0), -1), min=1e-5, max=1.-1e-5) + + # (batch, nsamples, nc, H, W) --> (batch * nsamples, nc, H, W) + x = x.unsqueeze(1).expand(batch_size, nsamples, *x.size()[1:]).contiguous() + tgt_vec = x.view(-1, *x.size()[2:]) + + # (batch * nsamples, *) + tgt_vec = tgt_vec.view(tgt_vec.size(0), -1) + + log_bernoulli = tgt_vec * torch.log(prob) + (1. - tgt_vec)*torch.log(1. - prob) + + log_bernoulli = log_bernoulli.view(batch_size, nsamples, -1) + + return -torch.sum(log_bernoulli, 2) + + + def log_probability(self, x, z): + """Cross Entropy in the language case + Args: + x: (batch_size, nc, H, W) + z: (batch_size, n_sample, nz) + Returns: + log_p: (batch_size, n_sample). 
+ log_p(x|z) across different x and z + """ + + return -self.reconstruct_error(x, z) diff --git a/modules/decoders/dec_pixelcnn_v2.py b/modules/decoders/dec_pixelcnn_v2.py new file mode 100755 index 0000000..6987511 --- /dev/null +++ b/modules/decoders/dec_pixelcnn_v2.py @@ -0,0 +1,232 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +import numpy as np + +from .decoder import DecoderBase + +class MaskedConv2d(nn.Conv2d): + def __init__(self, mask_type, masked_channels, *args, **kwargs): + super(MaskedConv2d, self).__init__(*args, **kwargs) + assert mask_type in {'A', 'B'} + self.register_buffer('mask', self.weight.data.clone()) + _, _, kH, kW = self.weight.size() + self.mask.fill_(1) + self.mask[:, :masked_channels, kH // 2, kW // 2 + (mask_type == 'B'):] = 0 + self.mask[:, :masked_channels, kH // 2 + 1:] = 0 + + def reset_parameters(self): + n = self.kernel_size[0] * self.kernel_size[1] * self.out_channels + self.weight.data.normal_(0, math.sqrt(2. 
class PixelCNNBlock(nn.Module):
    """Residual PixelCNN block: 1x1 conv -> masked ('B') kxk conv -> 1x1 conv,
    each followed by batch norm, with an ELU over the skip-connected sum."""

    def __init__(self, in_channels, kernel_size):
        super(PixelCNNBlock, self).__init__()
        self.mask_type = 'B'
        padding = kernel_size // 2
        out_channels = in_channels // 2

        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ELU(),
            MaskedConv2d(self.mask_type, out_channels, out_channels, out_channels,
                         kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ELU(),
            nn.Conv2d(out_channels, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
        )
        self.activation = nn.ELU()
        self.reset_parameters()

    def reset_parameters(self):
        """He-style init for convolutions; identity-like init for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, input):
        residual = input
        return self.activation(self.main(input) + residual)
self.blocks = [] + for i in range(num_blocks): + if i == 0: + block = MaskABlock(in_channels, out_channels, kernel_sizes[i], masked_channels) + else: + block = PixelCNNBlock(out_channels, kernel_sizes[i]) + self.blocks.append(block) + + self.main = nn.ModuleList(self.blocks) + + self.direct_connects = [] + for i in range(1, num_blocks - 1): + self.direct_connects.append(PixelCNNBlock(out_channels, kernel_sizes[i])) + + self.direct_connects = nn.ModuleList(self.direct_connects) + + def forward(self, input): + # [batch, out_channels, H, W] + direct_inputs = [] + for i, layer in enumerate(self.main): + if i > 2: + direct_input = direct_inputs.pop(0) + direct_conncet = self.direct_connects[i - 3] + input = input + direct_conncet(direct_input) + + input = layer(input) + direct_inputs.append(input) + assert len(direct_inputs) == 3, 'architecture error: %d' % len(direct_inputs) + direct_conncet = self.direct_connects[-1] + return input + direct_conncet(direct_inputs.pop(0)) + +class PixelCNNDecoderV2(DecoderBase): + def __init__(self, args, ngpu=1, mode='large'): + super(PixelCNNDecoderV2, self).__init__() + self.ngpu = ngpu + self.nz = args.nz + self.nc = 1 + # self.fm_latent = 4 old ??? 
hardcoded + self.fm_latent = args.latent_feature_map + + self.img_latent = 28 * 28 * self.fm_latent + if self.nz != 0 : + self.z_transform = nn.Sequential( + nn.Linear(self.nz, self.img_latent), + ) + if mode == 'small': + kernal_sizes = [7, 7, 7, 5, 5, 3, 3] + elif mode == 'large': + kernal_sizes = [7, 7, 7, 7, 7, 5, 5, 5, 5, 3, 3, 3, 3] + else: + raise ValueError('unknown mode: %s' % mode) + + hidden_channels = 64 + self.main = nn.Sequential( + PixelCNN(self.nc + self.fm_latent, hidden_channels, len(kernal_sizes), kernal_sizes, self.nc), + nn.Conv2d(hidden_channels, hidden_channels, 1, bias=False), + nn.BatchNorm2d(hidden_channels), + nn.ELU(), + nn.Conv2d(hidden_channels, self.nc, 1, bias=False), + nn.Sigmoid(), + ) + self.reset_parameters() + + def reset_parameters(self): + if self.nz != 0: + nn.init.xavier_uniform_(self.z_transform[0].weight) + nn.init.constant_(self.z_transform[0].bias, 0) + + m = self.main[2] + assert isinstance(m, nn.BatchNorm2d) + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + return output + + def reconstruct_error(self, x, z): + eps = 1e-12 + if type(z) == type(None): + batch_size, nsampels, _, _ = x.size() + img = x.unsqueeze(1).expand(batch_size, nsampels, *x.size()[1:]) + else: + batch_size, nsampels, nz = z.size() + # [batch, nsamples, -1] --> [batch, nsamples, fm, H, W] + z = self.z_transform(z).view(batch_size, nsampels, self.fm_latent, 28, 28) + + # [batch, nc, H, W] --> [batch, 1, nc, H, W] --> [batch, nsample, nc, H, W] + img = x.unsqueeze(1).expand(batch_size, nsampels, *x.size()[1:]) + # [batch, nsample, nc+fm, H, W] --> [batch * nsamples, nc+fm, H, W] + img = torch.cat([img, z], dim=2) + + img = img.view(-1, *img.size()[2:]) + + # [batch * nsamples, *] --> [batch, nsamples, -1] + recon_x = 
self.forward(img).view(batch_size, nsampels, -1) + # [batch, -1] + x_flat = x.view(batch_size, -1) + BCE = (recon_x + eps).log() * x_flat.unsqueeze(1) + (1.0 - recon_x + eps).log() * (1. - x_flat).unsqueeze(1) #cross-entropy + # [batch, nsamples] + return BCE.sum(dim=2) * -1.0 + + def log_probability(self, x, z): + bce = self.reconstruct_error(x, z) + return bce * -1. + + def decode(self, z, deterministic): + ''' + + Args: + z: Tensor + the tensor of latent z shape=[batch, nz] + deterministic: boolean + randomly sample of decode via argmaximizing probability + + Returns: Tensor + the tensor of decoded x shape=[batch, *] + + ''' + H = W = 28 + batch_size, nz = z.size() + + # [batch, -1] --> [batch, fm, H, W] + z = self.z_transform(z).view(batch_size, self.fm_latent, H, W) + img = Variable(z.data.new(batch_size, self.nc, H, W).zero_(), volatile=True) + # [batch, nc+fm, H, W] + img = torch.cat([img, z], dim=1) + for i in range(H): + for j in range(W): + # [batch, nc, H, W] + recon_img = self.forward(img) + # [batch, nc] + img[:, :self.nc, i, j] = torch.ge(recon_img[:, :, i, j], 0.5).float() if deterministic else torch.bernoulli(recon_img[:, :, i, j]) + # img[:, :self.nc, i, j] = torch.bernoulli(recon_img[:, :, i, j]) + + # [batch, nc, H, W] + img_probs = self.forward(img) + return img[:, :self.nc], img_probs diff --git a/modules/decoders/decoder.py b/modules/decoders/decoder.py new file mode 100755 index 0000000..a419661 --- /dev/null +++ b/modules/decoders/decoder.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + + +class DecoderBase(nn.Module): + """docstring for Decoder""" + def __init__(self): + super(DecoderBase, self).__init__() + + def decode(self, x, z): + + raise NotImplementedError + + def reconstruct_error(self, x, z): + """reconstruction loss + Args: + x: (batch_size, *) + z: (batch_size, n_sample, nz) + Returns: + loss: (batch_size, n_sample). 
class BeamSearchNode(object):
    """One hypothesis in a beam: a decoder state plus a back-pointer chain."""

    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        """
        :param hiddenstate: decoder hidden state at this step
        :param previousNode: parent BeamSearchNode (None for the root)
        :param wordId: token id emitted at this step
        :param logProb: cumulative log probability of the hypothesis
        :param length: number of tokens generated so far
        """
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        """Length-normalized score; `reward` is a placeholder for shaping."""
        reward = 0
        score = self.logp / float(self.leng - 1 + 1e-6)
        return score + alpha * reward
class LinearDiscriminator_only(nn.Module):
    """Linear probe that classifies latent features into `ncluster` classes."""

    def __init__(self, args, ncluster):
        super(LinearDiscriminator_only, self).__init__()
        self.args = args
        # The original `if args.IAF: ... else: ...` branches were byte-identical
        # (both nn.Linear(args.nz, ncluster)); a single definition removes the
        # dead duplication without changing behavior.
        self.linear = nn.Linear(args.nz, ncluster)
        # per-sample cross-entropy (no reduction)
        self.loss = nn.CrossEntropyLoss(reduction="none")

    def get_performance_with_feature(self, batch_data, batch_labels):
        """Classify pre-computed latent features.

        Args:
            batch_data: (batch, nz) latent features (posterior means)
            batch_labels: (batch,) long tensor of class labels

        Returns:
            (loss, correct): per-sample CE loss of shape (batch,), and the
            number of correct argmax predictions as a float
        """
        mu = batch_data
        logits = self.linear(mu)
        loss = self.loss(logits, batch_labels)

        _, pred = torch.max(logits, dim=1)
        correct = torch.eq(pred, batch_labels).float().sum().item()

        return loss, correct
[batch, nz] and + logvar[batch, nz] + """ + + z_T, log_q_z = self.forward(input, nsamples) + + return z_T, log_q_z + + def forward(self, x, n_sample): + """ + Args: + x: (batch_size, *) + + Returns: Tensor1, Tensor2 + Tensor1: the mean tensor, shape (batch, nz) + Tensor2: the logvar tensor, shape (batch, nz) + """ + + raise NotImplementedError + + def encode(self, input, args): + """perform the encoding and compute the KL term + + Returns: Tensor1, Tensor2 + Tensor1: the tensor latent z with shape [batch, nsamples, nz] + Tensor2: the tenor of KL for each x with shape [batch] + + """ + + # (batch, nsamples, nz) + z_T, log_q_z = self.forward(input, args.nsamples) + + log_p_z = self.log_q_z_0(z=z_T) # [b s nz] + + kl = log_q_z - log_p_z + + # free-bit + if self.training and args.fb == 1 and args.target_kl > 0: + kl_obj = torch.mean(kl, dim=[0, 1], keepdim=True) + kl_obj = torch.clamp_min(kl_obj, args.target_kl) + kl_obj = kl_obj.expand(kl.size(0), kl.size(1), -1) + kl = kl_obj + + return z_T, kl.sum(dim=[1, 2]) # like KL + + def reparameterize(self, mu, logvar, nsamples=1): + """sample from posterior Gaussian family + Args: + mu: Tensor + Mean of gaussian distribution with shape (batch, nz) + + logvar: Tensor + logvar of gaussian distibution with shape (batch, nz) + + Returns: Tensor + Sampled z with shape (batch, nsamples, nz) + """ + # import ipdb + # ipdb.set_trace() + batch_size, nz = mu.size() + std = logvar.mul(0.5).exp() + + mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz) + std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz) + + eps = torch.zeros_like(std_expd).normal_() + + return mu_expd + torch.mul(eps, std_expd) + + def eval_inference_dist(self, x, z, param=None): + """this function computes log q(z | x) + Args: + z: tensor + different z points that will be evaluated, with + shape [batch, nsamples, nz] + Returns: Tensor1 + Tensor1: log q(z|x) with shape [batch, nsamples] + """ + + nz = z.size(2) + + if not param: + mu, logvar = 
self.forward(x) + else: + mu, logvar = param + + # if self.args.gamma <0: + # mu,logvar = self.trans_param(mu,logvar) + + # import ipdb + # ipdb.set_trace() + + # (batch_size, 1, nz) + mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1) + var = logvar.exp() + + # (batch_size, nsamples, nz) + dev = z - mu + + # (batch_size, nsamples) + log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ + 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) + + return log_density + + +class VariationalFlow(IAFEncoderBase): + """Approximate posterior parameterized by a flow (https://arxiv.org/abs/1606.04934).""" + + def __init__(self, args, vocab_size, model_init, emb_init): + super().__init__() + + self.ni = args.ni + self.nh = args.enc_nh + self.nz = args.nz + self.args = args + + flow_depth = args.flow_depth + flow_width = args.flow_width + + self.embed = nn.Embedding(vocab_size, args.ni) + self.lstm = nn.LSTM(input_size=args.ni, hidden_size=args.enc_nh, num_layers=1, + batch_first=True, dropout=0) + + self.linear = nn.Linear(args.enc_nh, 4 * args.nz, bias=False) + modules = [] + for _ in range(flow_depth): + modules.append(InverseAutoregressiveFlow(num_input=args.nz, + num_hidden=flow_width * args.nz, # hidden dim in MADE + num_context=2 * args.nz)) + modules.append(Reverse(args.nz)) + + self.q_z_flow = FlowSequential(*modules) + self.log_q_z_0 = NormalLogProb() + self.softplus = nn.Softplus() + self.reset_parameters(model_init, emb_init) + + self.BN = False + if self.args.gamma > 0: + self.BN = True + self.mu_bn = nn.BatchNorm1d(args.nz, eps=1e-8) + self.gamma = args.gamma + nn.init.constant_(self.mu_bn.weight, self.args.gamma) + nn.init.constant_(self.mu_bn.bias, 0.0) + + self.DP = False + if self.args.p_drop > 0 and self.args.delta_rate > 0: + self.DP = True + self.p_drop = self.args.p_drop + self.delta_rate = self.args.delta_rate + + def reset_parameters(self, model_init, emb_init): + for name, param in self.lstm.named_parameters(): + # self.initializer(param) + if 'bias' in 
name: + nn.init.constant_(param, 0.0) + # model_init(param) + elif 'weight' in name: + model_init(param) + + model_init(self.linear.weight) + emb_init(self.embed.weight) + + def forward(self, input, n_samples): + """Return sample of latent variable and log prob.""" + word_embed = self.embed(input) + _, (last_state, last_cell) = self.lstm(word_embed) + loc_scale, h = self.linear(last_state.squeeze(0)).chunk(2, -1) + loc, scale_arg = loc_scale.chunk(2, -1) + scale = self.softplus(scale_arg) + + if self.BN: + ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5 + #if ss < self.gamma: + self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss + loc = self.mu_bn(loc) + if self.DP and self.args.kl_weight >= self.args.drop_start: + var = scale ** 2 + var = torch.dropout(var, p=self.p_drop, train=self.training) + var += self.delta_rate * 1.0 / (2 * math.e * math.pi) + scale = var ** 0.5 + + loc = loc.unsqueeze(1) + scale = scale.unsqueeze(1) + h = h.unsqueeze(1) + + eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device) + z_0 = loc + scale * eps # reparameterization + log_q_z_0 = self.log_q_z_0(loc=loc, scale=scale, z=z_0) + z_T, log_q_z_flow = self.q_z_flow(z_0, context=h) + log_q_z = (log_q_z_0 + log_q_z_flow) # [b s nz] + + if torch.sum(torch.isnan(z_T)): + import ipdb + ipdb.set_trace() + + ################ + if torch.rand(1).sum() <= 0.0005: + if self.BN: + self.mu_bn.weight + + return z_T, log_q_z + # return z_0, log_q_z_0.sum(-1) + + def infer_param(self, input): + word_embed = self.embed(input) + _, (last_state, last_cell) = self.lstm(word_embed) + loc_scale, h = self.linear(last_state.squeeze(0)).chunk(2, -1) + loc, scale_arg = loc_scale.chunk(2, -1) + scale = self.softplus(scale_arg) + # logvar = scale_arg + + if self.BN: + ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5 + if ss < self.gamma: + self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss + loc = self.mu_bn(loc) + if self.DP and self.args.kl_weight >= 
self.args.drop_start: + var = scale ** 2 + var = torch.dropout(var, p=self.p_drop, train=self.training) + var += self.delta_rate * 1.0 / (2 * math.e * math.pi) + scale = var ** 0.5 + + return loc, torch.log(scale ** 2) + + def learn_feature(self, input): + word_embed = self.embed(input) + _, (last_state, last_cell) = self.lstm(word_embed) + loc_scale, h = self.linear(last_state.squeeze(0)).chunk(2, -1) + loc, scale_arg = loc_scale.chunk(2, -1) + import ipdb + ipdb.set_trace() + if self.BN: + loc = self.mu_bn(loc) + loc = loc.unsqueeze(1) + h = h.unsqueeze(1) + z_T, log_q_z_flow = self.q_z_flow(loc, context=h) + return loc, z_T + + +from .enc_resnet_v2 import ResNet + + +class FlowResNetEncoderV2(IAFEncoderBase): + def __init__(self, args, ngpu=1): + super(FlowResNetEncoderV2, self).__init__() + self.ngpu = ngpu + self.nz = args.nz + self.nc = 1 + hidden_units = 512 + self.main = nn.Sequential( + ResNet(self.nc, [64, 64, 64], [2, 2, 2]), + nn.Conv2d(64, hidden_units, 4, 1, 0, bias=False), + nn.BatchNorm2d(hidden_units), + nn.ELU(), + ) + self.linear = nn.Linear(hidden_units, 4 * self.nz) + self.reset_parameters() + self.delta_rate = args.delta_rate + + self.args = args + + flow_depth = args.flow_depth + flow_width = args.flow_width + + modules = [] + for _ in range(flow_depth): + modules.append(InverseAutoregressiveFlow(num_input=args.nz, + num_hidden=flow_width * args.nz, # hidden dim in MADE + num_context=2 * args.nz)) + modules.append(Reverse(args.nz)) + + self.q_z_flow = FlowSequential(*modules) + self.log_q_z_0 = NormalLogProb() + self.softplus = nn.Softplus() + + self.BN = False + if self.args.gamma > 0: + self.BN = True + self.mu_bn = nn.BatchNorm1d(args.nz, eps=1e-8) + self.gamma = args.gamma + nn.init.constant_(self.mu_bn.weight, self.args.gamma) + nn.init.constant_(self.mu_bn.bias, 0.0) + + self.DP = False + if self.args.p_drop > 0 and self.args.delta_rate > 0: + self.DP = True + self.p_drop = self.args.p_drop + self.delta_rate = self.args.delta_rate + + 
def reset_parameters(self): + for m in self.main.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + nn.init.xavier_uniform_(self.linear.weight) + nn.init.constant_(self.linear.bias, 0.0) + + def forward(self, input, n_samples): + if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + output = self.linear(output.view(output.size()[:2])) + loc_scale, h = output.chunk(2, -1) + loc, scale_arg = loc_scale.chunk(2, -1) + scale = self.softplus(scale_arg) + + if self.BN: + ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5 + #if ss < self.gamma: + self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss + loc = self.mu_bn(loc) + + if self.DP and self.args.kl_weight >= self.args.drop_start: + var = scale ** 2 + var = torch.dropout(var, p=self.p_drop, train=self.training) + var += self.delta_rate * 1.0 / (2 * math.e * math.pi) + scale = var ** 0.5 + + loc = loc.unsqueeze(1) + scale = scale.unsqueeze(1) + h = h.unsqueeze(1) + + eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device) + z_0 = loc + scale * eps # reparameterization + log_q_z_0 = self.log_q_z_0(loc=loc, scale=scale, z=z_0) + z_T, log_q_z_flow = self.q_z_flow(z_0, context=h) + log_q_z = (log_q_z_0 + log_q_z_flow) # [b s nz] + + if torch.sum(torch.isnan(z_T)): + import ipdb + ipdb.set_trace() + + if torch.rand(1).sum() <= 0.001: + if self.BN: + self.mu_bn.weight + + return z_T, log_q_z + + + def infer_param(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + output = self.linear(output.view(output.size()[:2])) + loc_scale, h = 
output.chunk(2, -1) + loc, scale_arg = loc_scale.chunk(2, -1) + scale = self.softplus(scale_arg) + + if self.BN: + ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5 + if ss < self.gamma: + self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss + loc = self.mu_bn(loc) + if self.DP and self.args.kl_weight >= self.args.drop_start: + var = scale ** 2 + var = torch.dropout(var, p=self.p_drop, train=self.training) + var += self.delta_rate * 1.0 / (2 * math.e * math.pi) + scale = var ** 0.5 + + return loc, torch.log(scale ** 2) + + def learn_feature(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + output = self.linear(output.view(output.size()[:2])) + loc_scale, h = output.chunk(2, -1) + loc, _ = loc_scale.chunk(2, -1) + if self.BN: + ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5 + if ss < self.gamma: + self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss + loc = self.mu_bn(loc) + loc = loc.unsqueeze(1) + h = h.unsqueeze(1) + z_T, log_q_z_flow = self.q_z_flow(loc, context=h) + return loc, z_T + + +class NormalLogProb(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z, loc=None, scale=None): + if loc is None: + loc = torch.zeros_like(z, device=z.device) + if scale is None: + scale = torch.ones_like(z, device=z.device) + + var = torch.pow(scale, 2) + return -0.5 * torch.log(2 * np.pi * var) - torch.pow(z - loc, 2) / (2 * var) diff --git a/modules/encoders/enc_lstm.py b/modules/encoders/enc_lstm.py new file mode 100644 index 0000000..70dba5c --- /dev/null +++ b/modules/encoders/enc_lstm.py @@ -0,0 +1,120 @@ +import math +import torch +import torch.nn as nn +from .encoder import GaussianEncoderBase + + +class LSTMEncoder(GaussianEncoderBase): + """Gaussian LSTM Encoder with constant-length batching""" + + def __init__(self, args, vocab_size, model_init, emb_init): + 
super(LSTMEncoder, self).__init__() + self.ni = args.ni + self.nh = args.enc_nh + self.nz = args.nz + self.args = args + self.embed = nn.Embedding(vocab_size, args.ni) + + self.lstm = nn.LSTM(input_size=args.ni, + hidden_size=args.enc_nh, + num_layers=1, + batch_first=True, + dropout=0) + + # dimension transformation to z (mean and logvar) + self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False) + + self.reset_parameters(model_init, emb_init) + self.delta_rate = args.delta_rate + + def reset_parameters(self, model_init, emb_init): + for param in self.parameters(): + model_init(param) + emb_init(self.embed.weight) + + def forward(self, input): + """ + Args: + x: (batch_size, seq_len) + + Returns: Tensor1, Tensor2 + Tensor1: the mean tensor, shape (batch, nz) + Tensor2: the logvar tensor, shape (batch, nz) + """ + word_embed = self.embed(input) + _, (last_state, last_cell) = self.lstm(word_embed) + mean, logvar = self.linear(last_state).chunk(2, -1) + logvar = torch.log(torch.exp(logvar) + self.delta_rate * 1.0 / (2 * math.e * math.pi)) + + return mean.squeeze(0), logvar.squeeze(0) + + +class GaussianLSTMEncoder(GaussianEncoderBase): + """Gaussian LSTM Encoder with constant-length input""" + + def __init__(self, args, vocab_size, model_init, emb_init): + super(GaussianLSTMEncoder, self).__init__() + self.ni = args.ni + self.nh = args.enc_nh + self.nz = args.nz + self.args = args + + self.embed = nn.Embedding(vocab_size, args.ni) + + self.lstm = nn.LSTM(input_size=args.ni, + hidden_size=args.enc_nh, + num_layers=1, + batch_first=True, + dropout=0) + + self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False) + self.mu_bn = nn.BatchNorm1d(args.nz) + self.gamma = args.gamma + + self.reset_parameters(model_init, emb_init) + self.delta_rate = args.delta_rate + + def reset_parameters(self, model_init, emb_init, reset=False): + if not reset: + nn.init.constant_(self.mu_bn.weight, self.args.gamma) + else: + print('reset bn!') + if self.args.gamma_train: + 
nn.init.constant_(self.mu_bn.weight, self.args.gamma) + else: + self.mu_bn.weight.fill_(self.args.gamma) + nn.init.constant_(self.mu_bn.bias, 0.0) + + def forward(self, input): + """ + Args: + x: (batch_size, seq_len) + + Returns: Tensor1, Tensor2 + Tensor1: the mean tensor, shape (batch, nz) + Tensor2: the logvar tensor, shape (batch, nz) + """ + + # (batch_size, seq_len-1, args.ni) + word_embed = self.embed(input) + + _, (last_state, last_cell) = self.lstm(word_embed) + mean, logvar = self.linear(last_state).chunk(2, -1) + if self.args.gamma <= 0 or (mean.squeeze(0).size(0) == 1 and self.training == True): + mean = mean.squeeze(0) + else: + self.mu_bn.weight.requires_grad = True + ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5 + #if ss < self.gamma: + self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss + mean = self.mu_bn(mean.squeeze(0)) + + if torch.sum(torch.isnan(mean)) or torch.sum(torch.isnan(logvar)): + import ipdb + ipdb.set_trace() + + if self.args.kl_weight == 1: + logvar = torch.log(torch.exp(logvar) + self.delta_rate * 1.0 / (2 * math.e * math.pi)) + + return mean, logvar.squeeze(0) + diff --git a/modules/encoders/enc_resnet_v2.py b/modules/encoders/enc_resnet_v2.py new file mode 100755 index 0000000..d4a1bab --- /dev/null +++ b/modules/encoders/enc_resnet_v2.py @@ -0,0 +1,203 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from sympy import * +from .encoder import GaussianEncoderBase + +""" +A better ResNet baseline +""" + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, + kernel_size=3, stride=stride, padding=1, bias=False) + + +def deconv3x3(in_planes, out_planes, stride=1, output_padding=0): + "3x3 deconvolution with padding" + return nn.ConvTranspose2d(in_planes, out_planes, + kernel_size=3, stride=stride, padding=1, + output_padding=output_padding, bias=False) + + +class 
ResNetBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1): + super(ResNetBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.activation = nn.ELU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + downsample = None + if stride != 1 or inplanes != planes: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes), + ) + self.downsample = downsample + self.stride = stride + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + # [batch, planes, ceil(h/stride), ceil(w/stride)] + residual = x if self.downsample is None else self.downsample(x) + + # [batch, planes, ceil(h/stride), ceil(w/stride)] + out = self.conv1(x) + out = self.bn1(out) + out = self.activation(out) + + # [batch, planes, ceil(h/stride), ceil(w/stride)] + out = self.conv2(out) + out = self.bn2(out) + + out = self.activation(out + residual) + + # [batch, planes, ceil(h/stride), ceil(w/stride)] + return out + + +class ResNet(nn.Module): + def __init__(self, inplanes, planes, strides): + super(ResNet, self).__init__() + assert len(planes) == len(strides) + + blocks = [] + for i in range(len(planes)): + plane = planes[i] + stride = strides[i] + block = ResNetBlock(inplanes, plane, stride=stride) + blocks.append(block) + inplanes = plane + + self.main = nn.Sequential(*blocks) + + def forward(self, x): + return self.main(x) + + +class ResNetEncoderV2(GaussianEncoderBase): + def __init__(self, args, ngpu=1): + super(ResNetEncoderV2, self).__init__() + self.ngpu = ngpu + self.nz = args.nz + self.nc = 1 + hidden_units = 512 + self.main = nn.Sequential( + 
ResNet(self.nc, [64, 64, 64], [2, 2, 2]), + nn.Conv2d(64, hidden_units, 4, 1, 0, bias=False), + nn.BatchNorm2d(hidden_units), + nn.ELU(), + ) + self.linear = nn.Linear(hidden_units, 2 * self.nz) + self.reset_parameters() + self.delta_rate = args.delta_rate + + self.args = args + + def reset_parameters(self): + for m in self.main.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + nn.init.xavier_uniform_(self.linear.weight) + nn.init.constant_(self.linear.bias, 0.0) + + def forward(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + output = self.linear(output.view(output.size()[:2])) + mu, logvar = output.chunk(2, 1) + + return mu, logvar + + +class BNResNetEncoderV2(GaussianEncoderBase): + def __init__(self, args, ngpu=1): + super(BNResNetEncoderV2, self).__init__() + self.ngpu = ngpu + self.nz = args.nz + self.nc = 1 + hidden_units = 512 + self.main = nn.Sequential( + ResNet(self.nc, [64, 64, 64], [2, 2, 2]), + nn.Conv2d(64, hidden_units, 4, 1, 0, bias=False), + nn.BatchNorm2d(hidden_units), + nn.ELU(), + ) + self.linear = nn.Linear(hidden_units, 2 * self.nz) + self.mu_bn = nn.BatchNorm1d(args.nz) + self.gamma = args.gamma + self.args = args + + self.reset_parameters() + self.delta_rate = args.delta_rate + + def reset_parameters(self, reset=False): + for m in self.main.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + nn.init.xavier_uniform_(self.linear.weight) + nn.init.constant_(self.linear.bias, 0.0) + + if not reset: + nn.init.constant_(self.mu_bn.weight, self.gamma) + else: + print('reset bn!') + + nn.init.constant_(self.mu_bn.weight, self.gamma) + nn.init.constant_(self.mu_bn.bias, 0.0) + + def forward(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + output = self.linear(output.view(output.size()[:2])) + mu, logvar = output.chunk(2, 1) + if self.args.gamma > 0: + self.mu_bn.weight.requires_grad = True + ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5 + #if ss < self.gamma: + self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss + + mu = self.mu_bn(mu.squeeze(0)) + else: + mu = mu.squeeze(0) + + if self.args.kl_weight == 1: + logvar = torch.log(torch.exp(logvar) + self.delta_rate * 1.0 / (2 * math.e * math.pi)) + + if torch.rand(1).sum() <= 0.001: + scale = torch.exp(logvar / 2) + # print('gamma', self.mu_bn.weight) + # print('train loc mean', torch.mean(mu, dim=0)) + # print('train scale std', torch.std(scale, dim=0)) + # print('train scale mean', torch.mean(scale, dim=0)) + + return mu, logvar + diff --git a/modules/encoders/encoder.py b/modules/encoders/encoder.py new file mode 100644 index 0000000..d13cbc2 --- /dev/null +++ b/modules/encoders/encoder.py @@ -0,0 +1,189 @@ +import math +import torch +import torch.nn as nn +import random +from ..utils import log_sum_exp + + +class GaussianEncoderBase(nn.Module): + """docstring for EncoderBase""" + + def __init__(self): + super(GaussianEncoderBase, self).__init__() + + def forward(self, x): + """ + Args: + x: (batch_size, *) + + Returns: Tensor1, Tensor2 + Tensor1: the mean tensor, shape (batch, nz) + Tensor2: the logvar tensor, shape (batch, nz) + """ + + raise 
NotImplementedError + + def sample(self, input, nsamples): + """sampling from the encoder + Returns: Tensor1, Tuple + Tensor1: the tensor latent z with shape [batch, nsamples, nz] + Tuple: contains the tensor mu [batch, nz] and + logvar[batch, nz] + """ + + # (batch_size, nz) + mu, logvar = self.forward(input) + # if self.args.gamma<0: + # mu, logvar = self.trans_param(mu, logvar) + + # (batch, nsamples, nz) + z = self.reparameterize(mu, logvar, nsamples) + # if self.args.gamma <0: + # z=self.z_bn(z.squeeze(1)).unsqueeze(1) + + return z, (mu, logvar) + + def encode(self, input, args, training=True): + """perform the encoding and compute the KL term + + Returns: Tensor1, Tensor2 + Tensor1: the tensor latent z with shape [batch, nsamples, nz] + Tensor2: the tenor of KL for each x with shape [batch] + + """ + nsamples = args.nsamples + # (batch_size, nz) + mu, logvar = self.forward(input) + + if args.p_drop > 0 and training and args.kl_weight == 1: + var = logvar.exp() - args.delta_rate * 1.0 / (2 * math.e * math.pi) + var = torch.dropout(var, p=args.p_drop, train=training) + logvar = torch.log(var + args.delta_rate * 1.0 / (2 * math.e * math.pi)) + + # (batch, nsamples, nz) + z = self.reparameterize(mu, logvar, nsamples) + + KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1) + KL = KL.sum(dim=1) #B + + # if torch.sum(torch.isnan(HKL)): + # import ipdb + # ipdb.set_trace() + + return z, KL + + def MPD(self,mu,logvar): + eps = 1e-9 + z_shape = [mu.size(1)] + batch_size = mu.size(0) + # [batch, z_shape] + var = logvar.exp() + # B [batch, batch, z_shape] + B = (mu.unsqueeze(1) - mu.unsqueeze(0)).pow(2).div(var.unsqueeze(0) + eps) + B = B.mean(0).mean(0) #z_shape + A = var.mean(0) + C = (1/(var+eps)).mean(0) + # if torch.max(var)>10: + # import ipdb + # ipdb.set_trace() + return 0.5*(B+A*C-1) + + def CE(self, logvar): + return 0.5*(logvar.mean(dim=0) + math.log(2*math.pi*math.e)) + + def reparameterize(self, mu, logvar, nsamples=1): + """sample from posterior Gaussian 
family + Args: + mu: Tensor + Mean of gaussian distribution with shape (batch, nz) + + logvar: Tensor + logvar of gaussian distibution with shape (batch, nz) + + Returns: Tensor + Sampled z with shape (batch, nsamples, nz) + """ + + batch_size, nz = mu.size() + std = logvar.mul(0.5).exp() + + mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz) + std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz) + + eps = torch.zeros_like(std_expd).normal_() + + return mu_expd + torch.mul(eps, std_expd) + + def eval_inference_dist(self, x, z, param=None): + """this function computes log q(z | x) + Args: + z: tensor + different z points that will be evaluated, with + shape [batch, nsamples, nz] + Returns: Tensor1 + Tensor1: log q(z|x) with shape [batch, nsamples] + """ + + nz = z.size(2) + + if not param: + mu, logvar = self.forward(x) + else: + mu, logvar = param + + # (batch_size, 1, nz) + mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1) + var = logvar.exp() + + # (batch_size, nsamples, nz) + dev = z - mu + + # (batch_size, nsamples) + log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ + 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) + + return log_density + + def calc_mi(self, x): + """Approximate the mutual information between x and z + I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)) + + Returns: Float + + """ + + # [x_batch, nz] + mu, logvar = self.forward(x) + + # if self.args.gamma<0: + # mu, logvar = self.trans_param( mu, logvar) + + x_batch, nz = mu.size() + + # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1) + neg_entropy = (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).mean() + + # [z_batch, 1, nz] + z_samples = self.reparameterize(mu, logvar, 1) + + # [1, x_batch, nz] + mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0) + var = logvar.exp() + + # (z_batch, x_batch, nz) + dev = z_samples - mu + + # (z_batch, x_batch) + log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ + 0.5 * (nz * 
math.log(2 * math.pi) + logvar.sum(-1)) + + # log q(z): aggregate posterior + # [z_batch] + log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch) + + return (neg_entropy - log_qz.mean(-1)).item() + + + + + diff --git a/modules/encoders/flow.py b/modules/encoders/flow.py new file mode 100755 index 0000000..1b7bb94 --- /dev/null +++ b/modules/encoders/flow.py @@ -0,0 +1,152 @@ +"""Credit: mostly based on Ilya's excellent implementation here: https://github.com/ikostrikov/pytorch-flows""" +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class InverseAutoregressiveFlow(nn.Module): + """Inverse Autoregressive Flows with LSTM-type update. One block. + + Eq 11-14 of https://arxiv.org/abs/1606.04934 + """ + + def __init__(self, num_input, num_hidden, num_context): + super().__init__() + self.made = MADE(num_input=num_input, num_output=num_input * 2, + num_hidden=num_hidden, num_context=num_context) + # init such that sigmoid(s) is close to 1 for stability + self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2) + self.sigmoid = nn.Sigmoid() + self.log_sigmoid = nn.LogSigmoid() + + def forward(self, input, context=None): + m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1) + s = s + self.sigmoid_arg_bias + sigmoid = self.sigmoid(s) + z = sigmoid * input + (1 - sigmoid) * m + return z, -self.log_sigmoid(s) + + +class FlowSequential(nn.Sequential): + """Forward pass.""" + + def forward(self, input, context=None): + total_log_prob = torch.zeros_like(input, device=input.device) + for block in self._modules.values(): + input, log_prob = block(input, context) + total_log_prob += log_prob + return input, total_log_prob + + +class MaskedLinear(nn.Module): + """Linear layer with some input-output connections masked.""" + + def __init__(self, in_features, out_features, mask, context_features=None, bias=True): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + 
self.register_buffer("mask", mask) + if context_features is not None: + self.cond_linear = nn.Linear(context_features, out_features, bias=False) + + def forward(self, input, context=None): + output = F.linear(input, self.mask * self.linear.weight, self.linear.bias) + if context is None: + return output + else: + return output + self.cond_linear(context) + + +class MADE(nn.Module): + """Implements MADE: Masked Autoencoder for Distribution Estimation. + + Follows https://arxiv.org/abs/1502.03509 + + This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057). + """ + + def __init__(self, num_input, num_output, num_hidden, num_context): + super().__init__() + # m corresponds to m(k), the maximum degree of a node in the MADE paper + self._m = [] + self._masks = [] + self._build_masks(num_input, num_output, num_hidden, num_layers=3) + self._check_masks() + modules = [] + self.input_context_net = MaskedLinear(num_input, num_hidden, self._masks[0], num_context) + modules.append(nn.ReLU()) + modules.append(MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None)) + modules.append(nn.ReLU()) + modules.append(MaskedLinear(num_hidden, num_output, self._masks[2], context_features=None)) + self.net = nn.Sequential(*modules) + + def _build_masks(self, num_input, num_output, num_hidden, num_layers): + """Build the masks according to Eq 12 and 13 in the MADE paper.""" + rng = np.random.RandomState(0) + # assign input units a number between 1 and D + self._m.append(np.arange(1, num_input + 1)) + for i in range(1, num_layers + 1): + # randomly assign maximum number of input nodes to connect to + if i == num_layers: + # assign output layer units a number between 1 and D + m = np.arange(1, num_input + 1) + assert num_output % num_input == 0, "num_output must be multiple of num_input" + self._m.append(np.hstack([m for _ in range(num_output // num_input)])) + else: + # assign hidden layer units a number between 1 and D-1 + 
self._m.append(rng.randint(1, num_input, size=num_hidden)) + # self._m.append(np.arange(1, num_hidden + 1) % (num_input - 1) + 1) + if i == num_layers: + mask = self._m[i][None, :] > self._m[i - 1][:, None] + else: + # input to hidden & hidden to hidden + mask = self._m[i][None, :] >= self._m[i - 1][:, None] + # need to transpose for torch linear layer, shape (num_output, num_input) + self._masks.append(torch.from_numpy(mask.astype(np.float32).T)) + + def _check_masks(self): + """Check that the connectivity matrix between layers is lower triangular.""" + # (num_input, num_hidden) + prev = self._masks[0].t() + for i in range(1, len(self._masks)): + # num_hidden is second axis + prev = prev @ self._masks[i].t() + final = prev.numpy() + num_input = self._masks[0].shape[1] + num_output = self._masks[-1].shape[0] + assert final.shape == (num_input, num_output) + if num_output == num_input: + assert np.triu(final).all() == 0 + else: + for submat in np.split(final, + indices_or_sections=num_output // num_input, + axis=1): + assert np.triu(submat).all() == 0 + + def forward(self, input, context=None): + # first hidden layer receives input and context + hidden = self.input_context_net(input, context) + # rest of the network is conditioned on both input and context + return self.net(hidden) + + +class Reverse(nn.Module): + """ An implementation of a reversing layer from + Density estimation using Real NVP + (https://arxiv.org/abs/1605.08803). 
+ + From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py + """ + + def __init__(self, num_input): + super(Reverse, self).__init__() + self.perm = np.array(np.arange(0, num_input)[::-1]) + self.inv_perm = np.argsort(self.perm) + + def forward(self, inputs, context=None, mode='forward'): + if mode == "forward": + return inputs[:, :, self.perm], torch.zeros_like(inputs, device=inputs.device) + elif mode == "inverse": + return inputs[:, :, self.inv_perm], torch.zeros_like(inputs, device=inputs.device) + else: + raise ValueError("Mode must be one of {forward, inverse}.") diff --git a/modules/utils.py b/modules/utils.py new file mode 100755 index 0000000..9e2d020 --- /dev/null +++ b/modules/utils.py @@ -0,0 +1,39 @@ +import torch + +def log_sum_exp(value, dim=None, keepdim=False): + """Numerically stable implementation of the operation + value.exp().sum(dim, keepdim).log() + """ + if dim is not None: + m, _ = torch.max(value, dim=dim, keepdim=True) + value0 = value - m + if keepdim is False: + m = m.squeeze(dim) + return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) + else: + m = torch.max(value) + sum_exp = torch.sum(torch.exp(value - m)) + return m + torch.log(sum_exp) + + +def generate_grid(zmin, zmax, dz, device, ndim=2): + """generate a 1- or 2-dimensional grid + Returns: Tensor, int + Tensor: The grid tensor with shape (k^2, 2), + where k=(zmax - zmin)/dz + int: k + """ + + if ndim == 2: + x = torch.arange(zmin, zmax, dz) + k = x.size(0) + + x1 = x.unsqueeze(1).repeat(1, k).view(-1) + x2 = x.repeat(k) + + return torch.cat((x1.unsqueeze(-1), x2.unsqueeze(-1)), dim=-1).to(device), k + + elif ndim == 1: + return torch.arange(zmin, zmax, dz).unsqueeze(1).to(device) + + diff --git a/modules/vae.py b/modules/vae.py new file mode 100755 index 0000000..d5687bb --- /dev/null +++ b/modules/vae.py @@ -0,0 +1,303 @@ +import math +import torch +import torch.nn as nn + +from .utils import log_sum_exp + +class VAE(nn.Module): + """VAE with 
normal prior""" + def __init__(self, encoder, decoder, args): + super(VAE, self).__init__() + self.encoder = encoder + self.decoder = decoder + + self.args = args + + self.nz = args.nz + + loc = torch.zeros(self.nz, device=args.device) + scale = torch.ones(self.nz, device=args.device) + + self.prior = torch.distributions.normal.Normal(loc, scale) + + def encode(self, x, args,training=True): + """ + Returns: Tensor1, Tensor2 + Tensor1: the tensor latent z with shape [batch, nsamples, nz] + Tensor2: the tenor of KL for each x with shape [batch] + """ + return self.encoder.encode(x,args,training) + + def encode_stats(self, x): + """ + Returns: Tensor1, Tensor2 + Tensor1: the mean of latent z with shape [batch, nz] + Tensor2: the logvar of latent z with shape [batch, nz] + """ + + return self.encoder(x) + + def decode(self, z, strategy, K=5): + """generate samples from z given strategy + + Args: + z: [batch, nsamples, nz] + strategy: "beam" or "greedy" or "sample" + K: the beam width parameter + + Returns: List1 + List1: a list of decoded word sequence + """ + + if strategy == "beam": + return self.decoder.beam_search_decode(z, K) + elif strategy == "greedy": + return self.decoder.greedy_decode(z) + elif strategy == "sample": + return self.decoder.sample_decode(z) + else: + raise ValueError("the decoding strategy is not supported") + + def reconstruct(self, x, decoding_strategy="greedy", K=5,beta=1): + """reconstruct from input x + + Args: + x: (batch, *) + decoding_strategy: "beam" or "greedy" or "sample" + K: the beam width parameter (if applicable) + + Returns: List1 + List1: a list of decoded word sequence + """ + z = self.sample_from_inference(x,beta=beta).squeeze(1) + + return self.decode(z, decoding_strategy, K) + + def loss(self, x, kl_weight, args,training=True): + """ + Args: + x: if the data is constant-length, x is the data tensor with + shape (batch, *). 
Otherwise x is a tuple that contains + the data tensor and length list + + Returns: Tensor1, Tensor2, Tensor3 + Tensor1: total loss [batch] + Tensor2: reconstruction loss shape [batch] + Tensor3: KL loss shape [batch] + """ + + z, KL= self.encode(x, args,training) + + # (batch) + reconstruct_err = self.decoder.reconstruct_error(x, z).mean(dim=1) + if torch.sum(torch.isnan(reconstruct_err)): + import ipdb + ipdb.set_trace() + return reconstruct_err + kl_weight * KL, reconstruct_err, KL + + def rc_loss(self, x, y, kl_weight, args): + z, KL = self.encode(x, args) + reconstruct_err = self.decoder.reconstruct_error(y, z).mean(dim=1) + return reconstruct_err + kl_weight * KL, reconstruct_err, KL + + def nll_iw(self, x, nsamples, ns=100): + """compute the importance weighting estimate of the log-likelihood + Args: + x: if the data is constant-length, x is the data tensor with + shape (batch, *). Otherwise x is a tuple that contains + the data tensor and length list + nsamples: Int + the number of samples required to estimate marginal data likelihood + Returns: Tensor1 + Tensor1: the estimate of log p(x), shape [batch] + """ + # compute iw every ns samples to address the memory issue + # nsamples = 500, ns = 100 + # nsamples = 500, ns = 10 + tmp = [] + for _ in range(int(nsamples / ns)): + # [batch, ns, nz] + # param is the parameters required to evaluate q(z|x) + z, param = self.encoder.sample(x, ns) + + # [batch, ns] + log_comp_ll = self.eval_complete_ll(x, z) + log_infer_ll = self.eval_inference_dist(x, z, param) + tmp.append(log_comp_ll - log_infer_ll) + + ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples) + + return -ll_iw + + def KL(self, x, args): + # _, KL = self.encode(x, 1) + mu, logvar = self.encoder.forward(x) + KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) + + return KL + + def eval_prior_dist(self, zrange): + """perform grid search to calculate the true posterior + Args: + zrange: tensor + different z points that will be 
evaluated, with + shape (k^2, nz), where k=(zmax - zmin)/space + """ + + # (k^2) + return self.prior.log_prob(zrange).sum(dim=-1) + + def eval_complete_ll(self, x, z): + """compute log p(z,x) + Args: + x: Tensor + input with shape [batch, seq_len] + z: Tensor + evaluation points with shape [batch, nsamples, nz] + Returns: Tensor1 + Tensor1: log p(z,x) Tensor with shape [batch, nsamples] + """ + + # [batch, nsamples] + log_prior = self.eval_prior_dist(z) + log_gen = self.eval_cond_ll(x, z) + + return log_prior + log_gen + + def eval_cond_ll(self, x, z): + """compute log p(x|z) + """ + + return self.decoder.log_probability(x, z) + + def eval_log_model_posterior(self, x, grid_z): + """perform grid search to calculate the true posterior + this function computes p(z|x) + Args: + grid_z: tensor + different z points that will be evaluated, with + shape (k^2, nz), where k=(zmax - zmin)/pace + + Returns: Tensor + Tensor: the log posterior distribution log p(z|x) with + shape [batch_size, K^2] + """ + try: + batch_size = x.size(0) + except: + batch_size = x[0].size(0) + + # (batch_size, k^2, nz) + grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous() + + # (batch_size, k^2) + log_comp = self.eval_complete_ll(x, grid_z) + + # normalize to posterior + log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True) + + return log_posterior + + def sample_from_prior(self, nsamples): + """sampling from prior distribution + + Returns: Tensor + Tensor: samples from prior with shape (nsamples, nz) + """ + return self.prior.sample((nsamples,)) + + def sample_from_inference(self, x, nsamples=1,beta=1): + """perform sampling from inference net + Returns: Tensor + Tensor: samples from infernece nets with + shape (batch_size, nsamples, nz) + """ + z, (mu,logvar) = self.encoder.sample(x, nsamples) + + return mu # ??????? 
+ + def sample_from_posterior(self, x, nsamples): + """perform MH sampling from model posterior + Returns: Tensor + Tensor: samples from model posterior with + shape (batch_size, nsamples, nz) + """ + + # use the samples from inference net as initial points + # for MCMC sampling. [batch_size, nsamples, nz] + cur = self.encoder.sample_from_inference(x, 1) + cur_ll = self.eval_complete_ll(x, cur) + total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin + samples = [] + for iter_ in range(total_iter): + next = torch.normal(mean=cur, + std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std)) + # [batch_size, 1] + next_ll = self.eval_complete_ll(x, next) + ratio = next_ll - cur_ll + + accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size())) + + uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_() + + # [batch_size, 1] + mask = (uniform_t < accept_prob).float() + + mask_ = mask.unsqueeze(2) + + cur = mask_ * next + (1 - mask_) * cur + cur_ll = mask * next_ll + (1 - mask) * cur_ll + + if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0: + samples.append(cur.unsqueeze(1)) + + return torch.cat(samples, dim=1) + + def calc_model_posterior_mean(self, x, grid_z): + """compute the mean value of model posterior, i.e. 
E_{z ~ p(z|x)}[z] + Args: + grid_z: different z points that will be evaluated, with + shape (k^2, nz), where k=(zmax - zmin)/pace + x: [batch, *] + + Returns: Tensor1 + Tensor1: the mean value tensor with shape [batch, nz] + + """ + + # [batch, K^2] + log_posterior = self.eval_log_model_posterior(x, grid_z) + posterior = log_posterior.exp() + + # [batch, nz] + return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1) + + def calc_infer_mean(self, x): + """ + Returns: Tensor1 + Tensor1: the mean of inference distribution, with shape [batch, nz] + """ + + mean, logvar = self.encoder.forward(x) + # if self.args.gamma<0: + # mean,logvar=self.encoder.trans_param(mean,logvar) + + return mean + + def eval_inference_dist(self, x, z, param=None): + """ + Returns: Tensor + Tensor: the posterior density tensor with + shape (batch_size, nsamples) + """ + return self.encoder.eval_inference_dist(x, z, param) + + def calc_mi_q(self, x): + """Approximate the mutual information between x and z + under distribution q(z|x) + + Args: + x: [batch_size, *]. 
class VAEIAF(nn.Module):
    """VAE with a standard-normal prior and an IAF inference network.

    Relies on module-level `math`, `torch`, `nn` and `log_sum_exp`
    (imported at the top of this file).
    """

    def __init__(self, encoder, decoder, args):
        super(VAEIAF, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

        self.args = args
        self.nz = args.nz

        # standard-normal prior over the latent code
        loc = torch.zeros(self.nz, device=args.device)
        scale = torch.ones(self.nz, device=args.device)
        self.prior = torch.distributions.normal.Normal(loc, scale)

    def encode(self, x, args, training=True):
        """
        Returns: Tensor1, Tensor2
            Tensor1: latent z with shape [batch, nsamples, nz]
            Tensor2: KL for each x with shape [batch]
        """
        return self.encoder.encode(x, args)

    def encode_stats(self, x):
        """
        Returns: Tensor1, Tensor2
            Tensor1: mean of latent z with shape [batch, nz]
            Tensor2: logvar of latent z with shape [batch, nz]
        """
        return self.encoder.infer_param(x)

    def decode(self, z, strategy, K=5):
        """generate samples from z given strategy

        Args:
            z: [batch, nsamples, nz]
            strategy: "beam" or "greedy" or "sample"
            K: the beam width parameter

        Returns: List1
            List1: a list of decoded word sequences
        """
        if strategy == "beam":
            return self.decoder.beam_search_decode(z, K)
        elif strategy == "greedy":
            return self.decoder.greedy_decode(z)
        elif strategy == "sample":
            return self.decoder.sample_decode(z)
        else:
            raise ValueError("the decoding strategy is not supported")

    def reconstruct(self, x, decoding_strategy="greedy", K=5, beta=1):
        """reconstruct from input x

        Args:
            x: (batch, *)
            decoding_strategy: "beam" or "greedy" or "sample"
            K: the beam width parameter (if applicable)

        Returns: List1
            List1: a list of decoded word sequences
        """
        z = self.sample_from_inference(x, beta=beta).squeeze(1)
        return self.decode(z, decoding_strategy, K)

    def loss(self, x, kl_weight, args, training=True):
        """Compute the (annealed) negative ELBO per example.

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: total loss [batch]
            Tensor2: reconstruction loss [batch]
            Tensor3: KL loss [batch]
        """
        z, KL = self.encode(x, args, training)

        # (batch)
        reconstruct_err = self.decoder.reconstruct_error(x, z).mean(dim=1)
        # Fail fast on NaNs; the original dropped into ipdb here.
        if torch.isnan(reconstruct_err).any():
            raise FloatingPointError("NaN in reconstruction loss")
        return reconstruct_err + kl_weight * KL, reconstruct_err, KL

    def rc_loss(self, x, y, kl_weight, args):
        """Reconstruction loss of target y from the code of x, plus weighted KL."""
        z, KL = self.encode(x, args)
        reconstruct_err = self.decoder.reconstruct_error(y, z).mean(dim=1)
        return reconstruct_err + kl_weight * KL, reconstruct_err, KL

    def nll_iw(self, x, nsamples, ns=100):
        """compute the importance weighting estimate of the log-likelihood

        Args:
            nsamples: number of samples for the estimate
            ns: chunk size to bound memory (assumes nsamples % ns == 0)

        Returns: Tensor1
            Tensor1: the estimate of -log p(x), shape [batch]
        """
        tmp = []
        for _ in range(int(nsamples / ns)):
            # [batch, ns, nz]; the IAF encoder returns log q(z|x) directly
            z, log_infer_ll = self.encoder.sample(x, ns)
            log_infer_ll = log_infer_ll.sum(dim=-1)

            # [batch, ns]
            log_comp_ll = self.eval_complete_ll(x, z)
            tmp.append(log_comp_ll - log_infer_ll)

        ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)
        return -ll_iw

    def KL(self, x, args):
        """KL term of the ELBO, shape [batch] (Monte-Carlo via the IAF encoder)."""
        z, KL = self.encode(x, args, training=False)
        return KL

    def eval_prior_dist(self, zrange):
        """Log prior density at a grid of points.

        Args:
            zrange: tensor with shape (k^2, nz), k = (zmax - zmin) / space
        """
        return self.prior.log_prob(zrange).sum(dim=-1)

    def eval_complete_ll(self, x, z):
        """compute log p(z, x)

        Args:
            x: Tensor, [batch, seq_len]
            z: Tensor, [batch, nsamples, nz]

        Returns: Tensor1
            Tensor1: log p(z, x) with shape [batch, nsamples]
        """
        log_prior = self.eval_prior_dist(z)
        log_gen = self.eval_cond_ll(x, z)
        return log_prior + log_gen

    def eval_cond_ll(self, x, z):
        """compute log p(x|z)"""
        return self.decoder.log_probability(x, z)

    def eval_log_model_posterior(self, x, grid_z):
        """Grid search for the true posterior log p(z|x).

        Returns: Tensor
            Tensor: log p(z|x) with shape [batch_size, K^2]
        """
        try:
            batch_size = x.size(0)
        except AttributeError:
            # x is a (data, lengths) tuple rather than a bare tensor
            batch_size = x[0].size(0)

        # (batch_size, k^2, nz)
        grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()

        # (batch_size, k^2)
        log_comp = self.eval_complete_ll(x, grid_z)

        # normalize to posterior
        log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True)
        return log_posterior

    def sample_from_prior(self, nsamples):
        """sampling from prior distribution

        Returns: Tensor
            Tensor: samples from prior with shape (nsamples, nz)
        """
        return self.prior.sample((nsamples,))

    def sample_from_inference(self, x, nsamples=1, beta=1):
        """perform sampling from inference net

        Returns: Tensor
            Tensor: samples from the inference net with shape
                (batch_size, nsamples, nz); with beta == 0 the
                posterior mean is returned instead of a sample
        """
        z, (mu, logvar) = self.encoder.sample(x, nsamples)
        if beta == 0:
            z = mu
        # FIX: the original returned `mu` here, which made the beta branch
        # dead code and contradicted the docstring.
        return z

    def sample_from_posterior(self, x, nsamples):
        """perform Metropolis–Hastings sampling from the model posterior

        Returns: Tensor
            Tensor: samples with shape (batch_size, nsamples, nz)
        """
        # use the inference-net samples as initial points for MCMC
        cur = self.encoder.sample_from_inference(x, 1)
        cur_ll = self.eval_complete_ll(x, cur)
        total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin
        samples = []
        for iter_ in range(total_iter):
            # Gaussian random-walk proposal (renamed from `next`)
            proposal = torch.normal(mean=cur,
                                    std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std))
            # [batch_size, 1]
            proposal_ll = self.eval_complete_ll(x, proposal)
            ratio = proposal_ll - cur_ll

            accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size()))
            uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_()

            # [batch_size, 1]
            mask = (uniform_t < accept_prob).float()
            mask_ = mask.unsqueeze(2)

            cur = mask_ * proposal + (1 - mask_) * cur
            cur_ll = mask * proposal_ll + (1 - mask) * cur_ll

            if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0:
                samples.append(cur.unsqueeze(1))

        return torch.cat(samples, dim=1)

    def calc_model_posterior_mean(self, x, grid_z):
        """compute the mean value of the model posterior, i.e. E_{z ~ p(z|x)}[z]

        Returns: Tensor1
            Tensor1: the mean value tensor with shape [batch, nz]
        """
        # [batch, K^2]
        log_posterior = self.eval_log_model_posterior(x, grid_z)
        posterior = log_posterior.exp()
        # [batch, nz]
        return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1)

    def calc_infer_mean(self, x):
        """
        Returns: Tensor1
            Tensor1: the mean of the inference distribution, shape [batch, nz]
        """
        mean, logvar = self.encoder.forward(x)
        return mean

    def eval_inference_dist(self, x, z, param=None):
        """
        Returns: Tensor
            Tensor: the posterior density tensor with
                shape (batch_size, nsamples)
        """
        return self.encoder.eval_inference_dist(x, z, param)

    def calc_mi_q(self, x):
        """Approximate the mutual information between x and z under q(z|x).

        Args:
            x: [batch_size, *]. The sampled data to estimate mutual info
        """
        return self.encoder.calc_mi(x)
def process_data(x, pad=15):
    """Group (img, label) pairs by label and pad each class to >= `pad` images.

    Labels are (alphabet, character) pairs. Returns a dict mapping
    alphabet -> ndarray of shape (n_chars, n_imgs_per_char, *img_shape).
    """
    temp = dict()  # (alphabet, character) -> list of images
    for (img, label) in x:
        if label in temp.keys():
            temp[label].append(img)
        else:
            temp[label] = [img]

    # Pad under-represented classes by re-sampling their own images.
    for label in temp:
        if len(temp[label]) < pad:
            print(label, len(temp[label]))
            sn = pad - len(temp[label])
            # renamed loop var from `id` (shadowed the builtin)
            ids = np.random.choice(len(temp[label]), sn)
            for idx in ids:
                temp[label].append(temp[label][idx])

    # Regroup deterministically: sort by (alphabet, character).
    keys = sorted(list(temp.keys()))
    ttemp = dict()
    for key in keys:
        a, b = key
        if a not in ttemp:
            ttemp[a] = []
        ttemp[a].append(temp[key])
    for a in ttemp:
        ttemp[a] = np.array(ttemp[a])
    return ttemp


def feature(x, encoder, device, IAF=False):
    """Encode a (dim0, dim1, 28*28) image tensor to latent features.

    Returns a tensor of shape (dim0, dim1, feature_dim).
    """
    dim0 = x.size(0)
    dim1 = x.size(1)
    tmp = x.reshape(-1, 1, 28, 28)
    label = torch.zeros(tmp.size(0), 1)
    tmp_data = torch.utils.data.TensorDataset(tmp, label)
    loader = torch.utils.data.DataLoader(tmp_data, batch_size=256, shuffle=False)
    feats = []  # renamed from `feature` (shadowed this function's name)
    for datum in loader:
        batch_data, _ = datum
        batch_data = batch_data.to(device)
        if IAF:
            # IAF encoders expose both the base mean and the flowed code
            mu, zT = encoder.learn_feature(batch_data)
            mu = torch.cat([mu, zT], dim=-1)
        else:
            mu, _ = encoder(batch_data)
        feats.append(mu.detach())
    return torch.cat(feats, dim=0).reshape(dim0, dim1, -1)


class Omniglot:
    """Omniglot few-shot dataset, optionally pre-encoded with a VAE encoder.

    :param root: directory containing omniglot.npy / omniglot_dataset.pkl
    :param encoder: optional encoder used to replace raw images by features
    :param device: torch device used when encoding
    :param IAF: whether `encoder` is an IAF encoder
    """

    def __init__(self, root, encoder=None, device=None, IAF=False):
        print(root)
        pkl_path = os.path.join(root, 'omniglot_dataset.pkl')
        if not os.path.isfile(pkl_path):
            x_train, x_test = np.load(os.path.join(root, 'omniglot.npy'), allow_pickle=True)

            train_dict = process_data(x_train, pad=15)
            test_dict = process_data(x_test, pad=5)
            self.data_dict = {}
            for a in train_dict:
                self.data_dict[a] = {'train': train_dict[a], 'test': test_dict[a]}
            # close the file handle (the original leaked it)
            with open(pkl_path, 'wb') as f:
                pickle.dump(self.data_dict, f)
        else:
            # message fixed: the original said 'meta_learning_dataset.pt'
            print('load from omniglot_dataset.pkl.')
            with open(pkl_path, 'rb') as f:
                self.data_dict = pickle.load(f)

        if encoder:
            print('begin learning feature')
            for a in range(50):
                x_train = self.data_dict[a]['train']
                x_test = self.data_dict[a]['test']
                print(a, x_train.shape[0])
                x_train = torch.from_numpy(x_train).to(device)
                x_test = torch.from_numpy(x_test).to(device)
                # binarize pixels before encoding
                x_train = torch.bernoulli(x_train)
                x_test = torch.bernoulli(x_test)

                x_train = feature(x_train, encoder, device, IAF)
                x_test = feature(x_test, encoder, device, IAF)
                self.data_dict[a] = {'train': x_train, 'test': x_test}
            print('Done!')

    def load_task(self, i, trainnum=10):
        """Return (x_train, l_train, x_test, l_test, NC) for alphabet task i.

        i == 50 merges all 50 alphabets into one task. Handles both the
        encoded (torch tensor) and raw (numpy array) representations.
        """
        if i < 50:
            x_train = self.data_dict[i]['train']
            x_test = self.data_dict[i]['test']
        elif i == 50:
            x_train = torch.cat([self.data_dict[a]['train'] for a in range(50)], dim=0)
            x_test = torch.cat([self.data_dict[a]['test'] for a in range(50)], dim=0)
        else:
            # the original fell through to an UnboundLocalError here
            raise ValueError("task index must be in [0, 50], got %d" % i)
        try:
            # torch path: encoded features of shape (NC, N, D)
            x_train = x_train[:, :trainnum, :]
            NC, N = x_train.size()[:2]
            label = torch.tensor(range(NC)).unsqueeze(1).expand(-1, N)
            x_train = x_train.reshape(NC * N, -1)
            l_train = label.reshape(NC * N, -1)

            NC, N = x_test.size()[:2]
            label = torch.tensor(range(NC)).unsqueeze(1).expand(-1, N)
            x_test = x_test.reshape(NC * N, -1)
            l_test = label.reshape(NC * N, -1)

            return x_train, l_train, x_test, l_test, NC
        except (TypeError, AttributeError):
            # numpy path: raw images — ndarray.size is an int, so the
            # .size() call above raises TypeError
            NC, N = x_train.shape[:2]
            ds = x_train.shape[2:]
            label = np.array(range(NC))[:, np.newaxis].repeat(N, axis=1)
            x_train = x_train.reshape(NC * N, *ds)
            l_train = label.reshape(NC * N, -1)

            NC, N = x_test.shape[:2]
            ds = x_test.shape[2:]
            label = np.array(range(NC))[:, np.newaxis].repeat(N, axis=1)
            x_test = x_test.reshape(NC * N, *ds)
            l_test = label.reshape(NC * N, -1)

            return x_train, l_train, x_test, l_test, NC
Omniglot(root) + + diff --git a/text.py b/text.py new file mode 100755 index 0000000..af92354 --- /dev/null +++ b/text.py @@ -0,0 +1,500 @@ +import sys +import os +import time +import importlib +import argparse + +import numpy as np + +import torch +from torch import nn, optim + +from data import MonoTextData, VocabEntry +from modules import VAE +from modules import LSTMEncoder, LSTMDecoder, GaussianLSTMEncoder +from logger import Logger +from utils import calc_mi + + +clip_grad = 5.0 +decay_epoch = 5 +lr_decay = 0.5 +max_decay = 5 + + +def init_config(): + parser = argparse.ArgumentParser(description='VAE mode collapse study') + # model hyperparameters + parser.add_argument('--dataset', default='synthetic', type=str, help='dataset to use') + # optimization parameters + parser.add_argument('--momentum', type=float, default=0, help='sgd momentum') + parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training') + parser.add_argument('--iw_nsamples', type=int, default=500, + help='number of samples to compute importance weighted estimate') + # select mode + parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') + parser.add_argument('--load_path', type=str, default='') + # annealing paramters + parser.add_argument('--warm_up', type=int, default=10, help="number of annealing epochs") + parser.add_argument('--kl_start', type=float, default=0.0, help="starting KL weight") + # these are for slurm purpose to save model + parser.add_argument('--jobid', type=int, default=0, help='slurm job id') + parser.add_argument('--taskid', type=int, default=0, help='slurm task id') + parser.add_argument('--device', type=str, default="cpu") + parser.add_argument('--delta_rate', type=float, default=1, + help=" coontrol the minization of the variation of latent variables") + + # new + parser.add_argument("--target_kl", type=float, default=-1, + help="target kl of the free bits trick") + parser.add_argument('--gamma', 
type=float, default=1.0) # BN-VAE + parser.add_argument("--reset_dec", action="store_true", default=False) + parser.add_argument("--nz_new", type=int, default=32) + parser.add_argument('--p_drop', type=float, default=0.5) # p \in [0, 1] + parser.add_argument('--lr', type=float, default=1.0) # delta-VAE + args = parser.parse_args() + + if 'cuda' in args.device: + args.cuda = True + else: + args.cuda = False + + load_str = "_load" if args.load_path != "" else "" + load_str += '_eval' if args.eval else '' + save_dir = "models/%s%s/" % (args.dataset, load_str) + + if args.warm_up > 0 and args.kl_start < 1.0: + cw_str = '_warm%d' % args.warm_up + else: + cw_str = '' + + hkl_str = 'KL%.2f' % args.kl_start + drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else '' + seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909] + args.seed = seed_set[args.taskid] + + if args.gamma > 0: + gamma_str = '_gamma%.2f' % (args.gamma) + else: + gamma_str = '' + + momentum_str = '_m%.2f' % args.momentum if args.momentum > 0 else '' + id_ = "%s_%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d%s_lr%.1f" % \ + (args.dataset, hkl_str, + cw_str, load_str, gamma_str,args.delta_rate, + args.nz_new, drop_str, + args.jobid, args.taskid, args.seed, momentum_str, args.lr) + + save_dir += id_ + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + save_path = os.path.join(save_dir, 'model.pt') + + args.save_path = save_path + print("save path", args.save_path) + + args.log_path = os.path.join(save_dir, "log.txt") + print("log path", args.log_path) + + # load config file into args + config_file = "config.config_%s" % args.dataset + params = importlib.import_module(config_file).params + args = argparse.Namespace(**vars(args), **params) + if args.nz != args.nz_new: + args.nz = args.nz_new + + if 'label' in params: + args.label = params['label'] + else: + args.label = False + if 'vocab_file' not in params: + args.vocab_file = None + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if 
args.cuda: + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + + return args + + +def interploation(model, data, strategy, device): + for i, (batch_data, sent_len) in enumerate(data.data_iter(batch_size=2, device=device, + batch_first=True, shuffle=False)): + + z = model.sample_from_inference(batch_data).squeeze(1) + import ipdb + ipdb.set_trace() + z1 = z[[0], :] + z2 = z[[1], :] + ipdb.set_trace() + zs = [] + for alpha in [0, 0.2, 0.4, 0.6, 0.8, 1]: + z = alpha * z2 + (1 - alpha) * z1 + zs.append(z) + zs = torch.cat(zs, dim=0) + decoded_batch = model.decode(zs, strategy) + for sent in decoded_batch: + print(' '.join(sent) + '\n') + if i >= 10: + return + + +def sample_from_prior(model, z, strategy, fname): + with open(fname, "w") as fout: + decoded_batch = model.decode(z, strategy) + + for sent in decoded_batch: + fout.write(" ".join(sent) + "\n") + + +def test(model, test_data_batch, mode, args, verbose=True): + report_kl_loss = report_kl_t_loss = report_rec_loss = 0 + report_num_words = report_num_sents = 0 + try: + print(model.encoder.theta) + theta = model.encoder.theta + print('+', torch.sum((theta > 0))) + print('-', torch.sum((theta < 0))) + except: + pass + for i in np.random.permutation(len(test_data_batch)): + batch_data = test_data_batch[i] + batch_size, sent_len = batch_data.size() + + # not predict start symbol + report_num_words += (sent_len - 1) * batch_size + + report_num_sents += batch_size + + loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False) + loss_kl_t = model.KL(batch_data, args) + assert (not loss_rc.requires_grad) + assert (not loss_kl.requires_grad) + + loss_rc = loss_rc.sum() + loss_kl = loss_kl.sum() + loss_kl_t = loss_kl_t.sum() + + report_rec_loss += loss_rc.item() + report_kl_loss += loss_kl.item() + report_kl_t_loss += loss_kl_t.item() + + mutual_info = calc_mi(model, test_data_batch, device=args.device) + + test_loss = (report_rec_loss + report_kl_loss) / report_num_sents + + 
nll = (report_kl_t_loss + report_rec_loss) / report_num_sents + kl = report_kl_loss / report_num_sents + kl_t = report_kl_t_loss / report_num_sents + ppl = np.exp(nll * report_num_sents / report_num_words) + if verbose: + print('%s --- avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \ + (mode, test_loss, report_kl_t_loss / report_num_sents, mutual_info, + report_rec_loss / report_num_sents, nll, ppl)) + sys.stdout.flush() + + return test_loss, nll, kl_t, ppl, mutual_info # 返回真实的kl_t 不是训练中的kl + + +def calc_iwnll(model, test_data_batch, args, ns=100): + report_nll_loss = 0 + report_num_words = report_num_sents = 0 + for id_, i in enumerate(np.random.permutation(len(test_data_batch))): + batch_data = test_data_batch[i] + batch_size, sent_len = batch_data.size() + + # not predict start symbol + report_num_words += (sent_len - 1) * batch_size + + report_num_sents += batch_size + if id_ % (round(len(test_data_batch) / 10)) == 0: + print('iw nll computing %d0%%' % (id_ / (round(len(test_data_batch) / 10)))) + sys.stdout.flush() + + loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples, ns=ns) + + report_nll_loss += loss.sum().item() + + nll = report_nll_loss / report_num_sents + ppl = np.exp(nll * report_num_sents / report_num_words) + + print('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl)) + sys.stdout.flush() + return nll, ppl + + +def calc_au(model, test_data_batch, delta=0.01): + """compute the number of active units + """ + cnt = 0 + for batch_data in test_data_batch: + mean, _ = model.encode_stats(batch_data) + if cnt == 0: + means_sum = mean.sum(dim=0, keepdim=True) + else: + means_sum = means_sum + mean.sum(dim=0, keepdim=True) + cnt += mean.size(0) + + # (1, nz) + mean_mean = means_sum / cnt + + cnt = 0 + for batch_data in test_data_batch: + mean, _ = model.encode_stats(batch_data) + if cnt == 0: + var_sum = ((mean - mean_mean) ** 2).sum(dim=0) + else: + var_sum = var_sum + ((mean - mean_mean) ** 2).sum(dim=0) + cnt += mean.size(0) 
+ + # (nz) + au_var = var_sum / (cnt - 1) + + return (au_var >= delta).sum().item(), au_var + + +def main(args): + class uniform_initializer(object): + def __init__(self, stdv): + self.stdv = stdv + + def __call__(self, tensor): + nn.init.uniform_(tensor, -self.stdv, self.stdv) + + if args.cuda: + print('using cuda') + + print(args) + + opt_dict = {"not_improved": 0, "lr": args.lr, "best_loss": 1e4} + if args.vocab_file is not None: + print(args.vocab_file) + vocab = {} + with open(args.vocab_file) as fvocab: + for i, line in enumerate(fvocab): + vocab[line.strip()] = i + + vocab = VocabEntry(vocab) + train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab) + else: + train_data = MonoTextData(args.train_data, label=args.label) + + vocab = train_data.vocab + vocab_size = len(vocab) + + val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab) + test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab) + + print('Train data: %d samples' % len(train_data)) + print('finish reading datasets, vocab size is %d' % len(vocab)) + print('dropped sentences: %d' % train_data.dropped) + sys.stdout.flush() + + log_niter = (len(train_data) // args.batch_size) // 10 + + model_init = uniform_initializer(0.01) + emb_init = uniform_initializer(0.1) + + args.device = torch.device(args.device) + device = args.device + + if args.gamma > 0 and args.enc_type == 'lstm': + encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init) + args.enc_nh = args.dec_nh + elif args.gamma == 0 and args.enc_type == 'lstm': + encoder = LSTMEncoder(args, vocab_size, model_init, emb_init) + args.enc_nh = args.dec_nh + else: + raise ValueError("the specified encoder type is not supported") + + decoder = LSTMDecoder(args, vocab, model_init, emb_init) + + vae = VAE(encoder, decoder, args).to(device) + + if args.load_path: + loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device)) + vae.load_state_dict(loaded_state_dict, strict=False) 
+ print("%s loaded" % args.load_path) + + if args.reset_dec: + if args.gamma > 0: + vae.encoder.reset_parameters(model_init, emb_init, reset=True) + print("\n-------reset decoder-------\n") + vae.decoder.reset_parameters(model_init, emb_init) + + if args.eval: + print('begin evaluation') + args.kl_weight = 1 + vae.load_state_dict(torch.load(args.load_path, map_location=torch.device(device))) + vae.eval() + with torch.no_grad(): + test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, + device=device, + batch_first=True) + test(vae, test_data_batch, "TEST", args) + au, au_var = calc_au(vae, test_data_batch) + print("%d active units" % au) + # print(au_var) + test_data_batch = test_data.create_data_batch(batch_size=1, + device=device, + batch_first=True) + calc_iwnll(vae, test_data_batch, args) + return + + enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr, momentum=args.momentum) + dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr, momentum=args.momentum) + opt_dict['lr'] = args.lr + + iter_ = decay_cnt = 0 + best_loss = 1e4 + vae.train() + start = time.time() + kl_weight = args.kl_start + if args.warm_up > 0 and args.kl_start < 1.0: + anneal_rate = (1.0 - args.kl_start) / ( + args.warm_up * (len(train_data) / args.batch_size)) # kl_start ==0 时 anneal_rate==0 + else: + anneal_rate = 0 + + train_data_batch = train_data.create_data_batch(batch_size=args.batch_size, + device=device, + batch_first=True) + + val_data_batch = val_data.create_data_batch(batch_size=args.batch_size, + device=device, + batch_first=True) + + test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, + device=device, + batch_first=True) + for epoch in range(args.epochs): + report_kl_loss = report_rec_loss = 0 + report_num_words = report_num_sents = 0 + for i in np.random.permutation(len(train_data_batch)): # len(train_data_batch) + batch_data = train_data_batch[i] + batch_size, sent_len = batch_data.size() + if batch_data.size(0) < 16: 
+ continue + # not predict start symbol + report_num_words += (sent_len - 1) * batch_size + + report_num_sents += batch_size + + # kl_weight = 1.0 + if args.warm_up > 0 and args.kl_start < 1.0: + kl_weight = min(1.0, kl_weight + anneal_rate) + else: + kl_weight = 1.0 + + args.kl_weight = kl_weight + + enc_optimizer.zero_grad() + dec_optimizer.zero_grad() + loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args) + # loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples) + + loss = loss.mean(dim=0) + + loss.backward() + torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad) + + loss_rc = loss_rc.sum() + loss_kl = loss_kl.sum() + + enc_optimizer.step() + dec_optimizer.step() + + report_rec_loss += loss_rc.item() + report_kl_loss += loss_kl.item() + + if iter_ % log_niter == 0: + train_loss = (report_rec_loss + report_kl_loss) / report_num_sents + if epoch == 0: + vae.eval() + with torch.no_grad(): + mi = calc_mi(vae, val_data_batch, device=device) + au, _ = calc_au(vae, val_data_batch) + vae.train() + + print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f,' \ + 'au %d, time elapsed %.2fs' % + (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi, + report_rec_loss / report_num_sents, au, time.time() - start)) + else: + print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, recon: %.4f,' \ + 'time elapsed %.2fs' % + (epoch, iter_, train_loss, report_kl_loss / report_num_sents, + report_rec_loss / report_num_sents, time.time() - start)) + + sys.stdout.flush() + + report_rec_loss = report_kl_loss = 0 + report_num_words = report_num_sents = 0 + + iter_ += 1 + + print('kl weight %.4f' % args.kl_weight) + + vae.eval() + with torch.no_grad(): + loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args) + au, au_var = calc_au(vae, val_data_batch) + print("%d active units" % au) + # print(au_var) + + if loss < best_loss: + print('update best loss') + best_loss = loss + 
torch.save(vae.state_dict(), args.save_path) + # torch.save(vae.state_dict(), args.save_path) + if loss > opt_dict["best_loss"]: + opt_dict["not_improved"] += 1 + if opt_dict["not_improved"] >= decay_epoch and epoch >= 15: + opt_dict["best_loss"] = loss + opt_dict["not_improved"] = 0 + opt_dict["lr"] = opt_dict["lr"] * lr_decay + vae.load_state_dict(torch.load(args.save_path)) + print('new lr: %f' % opt_dict["lr"]) + decay_cnt += 1 + enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum) + dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum) + + else: + opt_dict["not_improved"] = 0 + opt_dict["best_loss"] = loss + + if decay_cnt == max_decay: + break + + if epoch % args.test_nepoch == 0: + with torch.no_grad(): + loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args) + + vae.train() + + # compute importance weighted estimate of log p(x) + vae.load_state_dict(torch.load(args.save_path)) + + vae.eval() + with torch.no_grad(): + loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args) + au, au_var = calc_au(vae, test_data_batch) + print("%d active units" % au) + # print(au_var) + + test_data_batch = test_data.create_data_batch(batch_size=1, + device=device, + batch_first=True) + with torch.no_grad(): + calc_iwnll(vae, test_data_batch, args) + return args + + +if __name__ == '__main__': + args = init_config() + if not args.eval: + sys.stdout = Logger(args.log_path) + main(args) diff --git a/text_IAF.py b/text_IAF.py new file mode 100644 index 0000000..79a30f4 --- /dev/null +++ b/text_IAF.py @@ -0,0 +1,447 @@ +import sys +import os +import time +import importlib +import argparse +import numpy as np +import torch +from torch import nn, optim +from data import MonoTextData, VocabEntry +from modules import VAEIAF as VAE +from modules import VariationalFlow, LSTMDecoder +from logger import Logger +from utils import calc_mi, calc_au + +clip_grad = 5.0 +decay_epoch = 5 
def init_config():
    """Parse CLI arguments for the VAE-IAF experiment and derive run settings.

    Builds the save/log paths from the hyper-parameter string, merges in the
    per-dataset config module, and seeds numpy/torch.

    Returns:
        argparse.Namespace with all CLI options plus derived fields
        (seed, cuda, save_path, log_path, label, vocab_file, ...).

    Raises:
        IndexError: if --taskid is outside the 10-entry seed table.
        ModuleNotFoundError: if config.config_<dataset> does not exist.
    """
    parser = argparse.ArgumentParser(description='VAE-IAF mode collapse study')

    # model hyperparameters
    parser.add_argument('--dataset', default='synthetic', type=str, help='dataset to use')

    # optimization parameters
    parser.add_argument('--momentum', type=float, default=0, help='sgd momentum')
    parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
    parser.add_argument('--iw_nsamples', type=int, default=500,
                        help='number of samples to compute importance weighted estimate')

    # select mode
    parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
    parser.add_argument('--load_path', type=str, default='')

    # annealing paramters
    parser.add_argument('--warm_up', type=int, default=100, help="number of annealing epochs")
    parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight")

    # these are for slurm purpose to save model
    parser.add_argument('--jobid', type=int, default=0, help='slurm job id')
    parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
    parser.add_argument('--device', type=str, default="cpu")
    parser.add_argument('--delta_rate', type=float, default=1.0,
                        help=" coontrol the minization of the variation of latent variables")

    parser.add_argument('--gamma', type=float, default=0.8)  # BN-VAE

    parser.add_argument("--nz_new", type=int, default=32)

    parser.add_argument('--p_drop', type=float, default=0.3)  # p \in [0, 1]
    parser.add_argument('--lr', type=float, default=1.0)  # delta-VAE

    parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow")
    parser.add_argument('--flow_width', type=int, default=2, help="width of flow")

    parser.add_argument("--fb", type=int, default=1,
                        help="0: no fb; 1: fb;")

    parser.add_argument("--target_kl", type=float, default=0.0,
                        help="target kl of the free bits trick")

    parser.add_argument('--drop_start', type=float, default=1.0, help="starting KL weight")

    args = parser.parse_args()

    # derive args.cuda from the requested device string
    if 'cuda' in args.device:
        args.cuda = True
    else:
        args.cuda = False

    load_str = "_load" if args.load_path != "" else ""
    save_dir = "models/%s%s/" % (args.dataset, load_str)

    # encode the KL-annealing schedule into the run id
    if args.warm_up > 0 and args.kl_start < 1.0:
        cw_str = '_warm%d' % args.warm_up + '_%.2f' % args.kl_start
    else:
        cw_str = ''

    if args.fb == 0:
        fb_str = ""
    elif args.fb in [1, 2]:
        fb_str = "_fb%d_tr%.2f" % (args.fb, args.target_kl)
    else:
        fb_str = ''

    drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else ''
    if 1.0 > args.drop_start > 0 and args.p_drop != 0:
        drop_str += '_start%.2f' % args.drop_start

    # taskid indexes a fixed seed table (0..9)
    seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
    args.seed = seed_set[args.taskid]

    if args.gamma > 0:
        gamma_str = '_gamma%.2f' % (args.gamma)
    else:
        gamma_str = ''

    # FIX: fd_str was previously only assigned when flow_depth > 0, leaving it
    # unbound (NameError) in the id_ format string when flow_depth <= 0.
    if args.flow_depth > 0:
        fd_str = '_fd%d_fw%d' % (args.flow_depth, args.flow_width)
    else:
        fd_str = ''

    momentum_str = '_m%.2f' % args.momentum if args.momentum > 0 else ''
    id_ = "%s%s%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d%s_lr%.1f_IAF" % \
          (args.dataset, cw_str, load_str, gamma_str, fb_str, fd_str,
           args.delta_rate, args.nz_new, drop_str,
           args.jobid, args.taskid, args.seed, momentum_str, args.lr)
    save_dir += id_

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    save_path = os.path.join(save_dir, 'model.pt')

    args.save_path = save_path
    print("save path", args.save_path)

    args.log_path = os.path.join(save_dir, "log.txt")
    print("log path", args.log_path)

    # load the per-dataset config module into args
    config_file = "config.config_%s" % args.dataset
    params = importlib.import_module(config_file).params
    args = argparse.Namespace(**vars(args), **params)
    # --nz_new overrides the config's latent size
    if args.nz != args.nz_new:
        args.nz = args.nz_new

    if 'label' in params:
        args.label = params['label']
    else:
        args.label = False
    if 'vocab_file' not in params:
        args.vocab_file = None

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

    return args
None + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.cuda: + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + + return args + + +def test(model, test_data_batch, mode, args, verbose=True): + report_kl_loss = report_kl_t_loss = report_rec_loss = 0 + report_num_words = report_num_sents = 0 + + for i in np.random.permutation(len(test_data_batch)): + batch_data = test_data_batch[i] + batch_size, sent_len = batch_data.size() + + # not predict start symbol + report_num_words += (sent_len - 1) * batch_size + + report_num_sents += batch_size + + loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False) + loss_kl_t = model.KL(batch_data, args) + assert (not loss_rc.requires_grad) + assert (not loss_kl.requires_grad) + + loss_rc = loss_rc.sum() + loss_kl = loss_kl.sum() + loss_kl_t = loss_kl_t.sum() + + report_rec_loss += loss_rc.item() + report_kl_loss += loss_kl.item() + report_kl_t_loss += loss_kl_t.item() + + mutual_info = calc_mi(model, test_data_batch, device=args.device) + + test_loss = (report_rec_loss + report_kl_loss) / report_num_sents + + nll = (report_kl_t_loss + report_rec_loss) / report_num_sents + kl = report_kl_loss / report_num_sents + kl_t = report_kl_t_loss / report_num_sents + ppl = np.exp(nll * report_num_sents / report_num_words) + if verbose: + print('%s --- avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \ + (mode, test_loss, report_kl_t_loss / report_num_sents, mutual_info, + report_rec_loss / report_num_sents, nll, ppl)) + sys.stdout.flush() + + return test_loss, nll, kl_t, ppl, mutual_info # 返回真实的kl_t 不是训练中的kl + + +def calc_iwnll(model, test_data_batch, args, ns=100): + report_nll_loss = 0 + report_num_words = report_num_sents = 0 + for id_, i in enumerate(np.random.permutation(len(test_data_batch))): + batch_data = test_data_batch[i] + batch_size, sent_len = batch_data.size() + + # not predict start symbol + report_num_words += (sent_len - 1) * 
def calc_iwnll(model, test_data_batch, args, ns=100):
    """Importance-weighted estimate of NLL and per-word PPL over a test set.

    Args:
        model: VAE exposing nll_iw(batch, nsamples, ns).
        test_data_batch: list of (batch_size, sent_len) id tensors.
        args: namespace read for .iw_nsamples.
        ns: chunk size forwarded to nll_iw (samples evaluated at a time).

    Returns:
        (nll, ppl) — per-sentence NLL and per-word perplexity.

    Raises:
        ZeroDivisionError: if test_data_batch is empty.
    """
    report_nll_loss = 0
    report_num_words = report_num_sents = 0
    # FIX: progress is printed roughly every 10% of batches; round(len/10)
    # is 0 for small test sets, which made `id_ % 0` raise ZeroDivisionError.
    log_every = max(1, round(len(test_data_batch) / 10))
    for id_, i in enumerate(np.random.permutation(len(test_data_batch))):
        batch_data = test_data_batch[i]
        batch_size, sent_len = batch_data.size()

        # not predict start symbol
        report_num_words += (sent_len - 1) * batch_size

        report_num_sents += batch_size
        if id_ % log_every == 0:
            print('iw nll computing %d0%%' % (id_ / log_every))
            sys.stdout.flush()

        loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples, ns=ns)

        report_nll_loss += loss.sum().item()

    nll = report_nll_loss / report_num_sents
    ppl = np.exp(nll * report_num_sents / report_num_words)

    print('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
    sys.stdout.flush()
    return nll, ppl
def main(args):
    """Train (or evaluate, with --eval) the VAE-IAF model.

    Builds vocab/data, constructs the flow encoder + LSTM decoder VAE,
    runs SGD training with KL annealing and lr decay on validation
    plateaus, then reports test metrics and the IW-NLL estimate.
    """
    class uniform_initializer(object):
        # uniform init in [-stdv, stdv]
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": args.lr, "best_loss": 1e4}

    # build vocab from file when provided, otherwise from the training data
    if args.vocab_file is not None:
        print(args.vocab_file)
        vocab = {}
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i
        vocab = VocabEntry(vocab)
        train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
    else:
        train_data = MonoTextData(args.train_data, label=args.label)

    vocab = train_data.vocab

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # log roughly 10 times per epoch
    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    args.device = torch.device(args.device)
    device = args.device

    encoder = VariationalFlow(args, vocab_size, model_init, emb_init)
    args.enc_nh = args.dec_nh

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device))
        # strict=False: tolerate missing/extra keys in the checkpoint
        vae.load_state_dict(loaded_state_dict, strict=False)
        print("%s loaded" % args.load_path)

    if args.eval:
        # evaluation-only mode: test metrics, active units, then IW-NLL
        print('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path, map_location=torch.device(device)))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                          device=device,
                                                          batch_first=True)
            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            print("%d active units" % au)
            # print(au_var)
            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            calc_iwnll(vae, test_data_batch, args)
        return

    enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr, momentum=args.momentum)
    dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr, momentum=args.momentum)
    opt_dict['lr'] = args.lr

    iter_ = decay_cnt = 0
    best_loss = 1e4
    vae.train()
    start = time.time()

    # linear KL annealing from kl_start to 1.0 over warm_up epochs
    kl_weight = args.kl_start
    if args.warm_up > 0 and args.kl_start < 1.0:
        anneal_rate = (1.0 - args.kl_start) / (
                args.warm_up * (len(train_data) / args.batch_size))  # anneal_rate == 0 when kl_start == 0
    else:
        anneal_rate = 0

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)
    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):  # len(train_data_batch)
            batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()
            # skip tiny tail batches
            if batch_data.size(0) < 16:
                continue

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size

            report_num_sents += batch_size

            # kl_weight = 1.0
            if args.warm_up > 0 and args.kl_start < 1.0:
                kl_weight = min(1.0, kl_weight + anneal_rate)
            else:
                kl_weight = 1.0

            args.kl_weight = kl_weight

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()
            if args.fb == 0 or args.fb == 1:
                loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args)

            # NOTE(review): `loss` is only assigned when fb is 0 or 1; any other
            # fb value would hit a NameError here — confirm intended fb range.
            loss = loss.mean(dim=0)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            enc_optimizer.step()
            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                # during epoch 0 also report MI and active units on validation
                if epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch, device=device)
                        au, _ = calc_au(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f,' \
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            iter_ += 1

        print('kl weight %.4f' % args.kl_weight)

        # end-of-epoch validation
        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            au, au_var = calc_au(vae, val_data_batch)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            torch.save(vae.state_dict(), args.save_path)
            # torch.save(vae.state_dict(), args.save_path)
        # lr-decay-on-plateau: after decay_epoch bad epochs (and epoch >= 15),
        # reload the best checkpoint and halve the lr
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
                dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum)

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        print("%d active units" % au)
        # print(au_var)

    # IW-NLL needs batch_size 1
    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
    return args


if __name__ == '__main__':
    args = init_config()
    if not args.eval:
        # mirror stdout into the run's log file
        sys.stdout = Logger(args.log_path)
    main(args)
def init_config():
    """Parse CLI arguments for the semi-supervised fine-tuning experiment.

    Seeds RNGs, loads the per-dataset/per-label-count config module, and
    derives the experiment directory and save path.

    Returns:
        argparse.Namespace with CLI options plus derived fields
        (seed, cuda, exp_dir, save_path, label, kl_weight).

    Raises:
        ValueError: if --num_label has no matching params_ss_* config.
        IndexError: if --taskid is outside the 10-entry seed table.
    """
    parser = argparse.ArgumentParser(description='VAE mode collapse study')
    parser.add_argument('--gamma', type=float, default=0.0)
    # model hyperparameters
    parser.add_argument('--dataset', default='yelp', type=str, help='dataset to use')
    # optimization parameters
    parser.add_argument('--momentum', type=float, default=0, help='sgd momentum')
    parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="sgd", help='sgd momentum')

    parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
    parser.add_argument('--iw_nsamples', type=int, default=500,
                        help='number of samples to compute importance weighted estimate')

    # select mode
    parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
    parser.add_argument('--load_path', type=str,
                        default='short_yelp_aggressive0_hs1.00_warm100_0_0_783435.pt')  # TODO: set load_path

    # annealing paramters
    parser.add_argument('--warm_up', type=int, default=100,
                        help="number of annealing epochs. warm_up=0 means not anneal")
    parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight")

    # output directory
    parser.add_argument('--exp_dir', default=None, type=str,
                        help='experiment directory.')
    parser.add_argument("--save_ckpt", type=int, default=0,
                        help="save checkpoint every epoch before this number")
    parser.add_argument("--save_latent", type=int, default=0)

    # new
    parser.add_argument("--reset_dec", action="store_true", default=True)
    parser.add_argument("--load_best_epoch", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1.)

    parser.add_argument("--batch_size", type=int, default=16,
                        help="target kl of the free bits trick")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of epochs")
    parser.add_argument("--update_every", type=int, default=1,
                        help="target kl of the free bits trick")
    parser.add_argument("--num_label", type=int, default=100,
                        help="target kl of the free bits trick")
    parser.add_argument("--freeze_enc", action="store_true",
                        default=True)  # True -> freeze the parameters of vae.encoder
    parser.add_argument("--discriminator", type=str, default="linear")

    parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--delta_rate', type=float, default=0.0,
                        help=" coontrol the minization of the variation of latent variables")

    parser.add_argument('--p_drop', type=float, default=0)
    parser.add_argument('--IAF', action='store_true', default=False)
    parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow")
    parser.add_argument('--flow_width', type=int, default=60, help="width of flow")

    args = parser.parse_args()

    # set args.cuda
    if 'cuda' in args.device:
        args.cuda = True
    else:
        args.cuda = False

    # set seeds
    seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
    args.seed = seed_set[args.taskid]
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

    # load config file into args
    config_file = "config.config_%s" % args.dataset
    if args.num_label == 10:
        params = importlib.import_module(config_file).params_ss_10
    elif args.num_label == 100:
        params = importlib.import_module(config_file).params_ss_100
    elif args.num_label == 500:
        params = importlib.import_module(config_file).params_ss_500
    elif args.num_label == 1000:
        params = importlib.import_module(config_file).params_ss_1000
    elif args.num_label == 2000:
        params = importlib.import_module(config_file).params_ss_2000
    elif args.num_label == 10000:
        params = importlib.import_module(config_file).params_ss_10000
    else:
        # FIX: previously an unsupported num_label left `params` unbound and
        # crashed later with a confusing NameError.
        raise ValueError("unsupported --num_label %d (expected 10/100/500/1000/2000/10000)"
                         % args.num_label)

    args = argparse.Namespace(**vars(args), **params)

    load_str = "_load" if args.load_path != "" else ""

    opt_str = "_adam" if args.opt == "adam" else "_sgd"
    nlabel_str = "_nlabel{}".format(args.num_label)
    freeze_str = "_freeze" if args.freeze_enc else ""

    # FIX: a bare filename (no "/") previously raised IndexError on [1];
    # parts[-1] preserves the old result for 2+ segments and handles 1.
    parts = args.load_path.split("/")
    if len(parts) > 2:
        load_path_str = parts[2]
    else:
        load_path_str = parts[-1]

    model_str = "_{}".format(args.discriminator)
    # set load and save paths
    if args.exp_dir is None:
        args.exp_dir = "models/exp_{}{}_ss_ft/{}{}{}{}{}".format(args.dataset,
                                                                 load_str, load_path_str, model_str, opt_str,
                                                                 nlabel_str, freeze_str)

    if len(args.load_path) <= 0 and args.eval:
        args.load_path = os.path.join(args.exp_dir, 'model.pt')

    args.save_path = os.path.join(args.exp_dir, 'model.pt')

    # set args.label
    if 'label' in params:
        args.label = params['label']
    else:
        args.label = False

    args.kl_weight = 1

    return args
def test(model, test_data_batch, test_labels_batch, mode, args, verbose=True):
    """Evaluate a discriminator on pre-computed latent features.

    Args:
        model: discriminator exposing get_performance_with_feature(data, labels)
               -> (per-example loss, number correct).
        test_data_batch: list of feature batches.
        test_labels_batch: parallel list of label sequences (castable to int).
        mode: tag for the log line (e.g. "VAL", "TEST").
        args: namespace read for .device.
        verbose: log the summary line via the module-level `logging` callable.

    Returns:
        (test_loss, acc) — mean loss and accuracy over all examples.
    """
    global logging

    report_correct = report_loss = 0
    report_num_sents = 0
    for i in np.random.permutation(len(test_data_batch)):
        batch_data = test_data_batch[i]
        batch_labels = test_labels_batch[i]
        batch_labels = [int(x) for x in batch_labels]

        batch_labels = torch.tensor(batch_labels, dtype=torch.long, requires_grad=False, device=args.device)

        batch_size = batch_data.size(0)

        # not predict start symbol
        report_num_sents += batch_size

        loss, correct = model.get_performance_with_feature(batch_data, batch_labels)

        loss = loss.sum()

        report_loss += loss.item()
        report_correct += correct

    test_loss = report_loss / report_num_sents
    acc = report_correct / report_num_sents

    if verbose:
        logging('%s --- avg_loss: %.4f, acc: %.4f' % \
                (mode, test_loss, acc))
        # sys.stdout.flush()

    return test_loss, acc
def main(args):
    """Fine-tune a linear discriminator on latent features of a trained VAE.

    Loads a pre-trained VAE (frozen, eval mode), extracts latent features
    for train/val/test once, then trains the discriminator with gradient
    accumulation (--update_every) and lr decay on validation plateaus.
    """
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # vocab is always read from file in the semi-supervised setting
    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i

    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    # sys.stdout.flush()

    # log roughly 10 times per epoch of optimizer steps; at least 1
    log_niter = max(1, (len(train_data) // (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    # device = torch.device("cuda" if args.cuda else "cpu")
    # device = "cuda" if args.cuda else "cpu"
    device = args.device

    # select the encoder matching the pre-trained checkpoint's architecture
    if args.gamma > 0 and args.enc_type == 'lstm' and not args.IAF:
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    elif args.gamma == 0 and args.enc_type == 'lstm' and not args.IAF:
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    elif args.IAF:
        encoder = VariationalFlow(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    # the VAE stays frozen in eval mode; only the discriminator trains
    vae = VAE(encoder, decoder, args).to(device)
    vae.eval()

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device))
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # NOTE(review): bare except deliberately swallows AttributeError for
    # encoders without `theta`; a narrower `except AttributeError` would be safer.
    try:
        print('theta', vae.encoder.theta)
    except:
        pass
    if args.discriminator == "linear":
        discriminator = LinearDiscriminator_only(args, args.ncluster).to(device)
        # elif args.discriminator == "mlp":
        #     discriminator = MLPDiscriminator(args, vae.encoder).to(device)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(), lr=args.lr, momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        # optimizer = swats.SWATS(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # best_kl = best_nll = best_ppl = 0
    # pre_mi = 0
    discriminator.train()
    start = time.time()

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(batch_size=args.batch_size,
                                                                               device=device,
                                                                               batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(batch_size=128,
                                                                         device=device,
                                                                         batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(batch_size=128,
                                                                            device=device,
                                                                            batch_first=True)

    def learn_feature(data_batch, labels_batch):
        # Encode every batch once with the frozen VAE; training then runs
        # on the detached latent features instead of raw token ids.
        feature = []
        label = []
        for i in np.random.permutation(len(data_batch)):
            batch_data = data_batch[i]
            batch_labels = labels_batch[i]
            batch_data = batch_data.to(device)
            batch_size = batch_data.size(0)
            if args.IAF:
                loc, zT = vae.encoder.learn_feature(batch_data)
                # mu = torch.cat([loc, zT], dim=-1)
                mu = zT  # use the flow output zT as the feature
                mu = mu.squeeze(1)
            else:
                mu, logvar = vae.encoder(batch_data)
            feature.append(mu.detach())
            label.append(batch_labels)
        return feature, label

    # replace token batches with (feature, label) batches
    train_data_batch, train_labels_batch = learn_feature(train_data_batch, train_labels_batch)
    val_data_batch, val_labels_batch = learn_feature(val_data_batch, val_labels_batch)
    test_data_batch, test_labels_batch = learn_feature(test_data_batch, test_labels_batch)

    acc_cnt = 1
    acc_loss = 0.
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            # skip degenerate batches
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels, dtype=torch.long, requires_grad=False, device=device)

            batch_size = batch_data.size(0)

            # not predict start symbol
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance_with_feature(batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            # gradient accumulation: step every update_every mini-batches
            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                # NOTE(review): train_loss is computed but never logged/used
                train_loss = report_loss / report_num_sents

            iter_ += 1

        # logging('lr {}'.format(opt_dict["lr"]))
        # print(report_num_sents)
        discriminator.eval()

        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch, "VAL", args)
            # print(au_var)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc  # NOTE(review): best_acc is stored but never reported
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        # lr-decay-on-plateau, mirroring the VAE training scripts
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(), lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)

        discriminator.train()

    # final evaluation of the best checkpoint on the test set
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()

    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)
        # print(au_var)


if __name__ == '__main__':
    args = init_config()
    main(args)
def calc_iwnll(model, test_data_batch, args, ns=100):
    """Importance-weighted NLL / per-word PPL estimate (utils variant).

    Prints inline percentage progress on one line; unlike the script-local
    variants, the caller is responsible for printing the returned values.

    Args:
        model: VAE exposing nll_iw(batch, nsamples, ns).
        test_data_batch: list of (batch_size, sent_len) id tensors.
        args: namespace read for .iw_nsamples.
        ns: chunk size forwarded to nll_iw.

    Returns:
        (nll, ppl).

    Raises:
        ZeroDivisionError: if test_data_batch is empty.
    """
    report_nll_loss = 0
    report_num_words = report_num_sents = 0
    print("iw nll computing ", end="")
    # FIX: progress every ~5% of batches; round(len/20) is 0 for small test
    # sets, which made `id_ % 0` raise ZeroDivisionError.
    log_every = max(1, round(len(test_data_batch) / 20))
    for id_, i in enumerate(np.random.permutation(len(test_data_batch))):
        batch_data = test_data_batch[i]
        batch_size, sent_len = batch_data.size()

        # not predict start symbol
        report_num_words += (sent_len - 1) * batch_size

        report_num_sents += batch_size
        if id_ % log_every == 0:
            print('%d%% ' % (id_ / log_every * 5), end="")
            sys.stdout.flush()

        loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples, ns=ns)

        report_nll_loss += loss.sum().item()

    print()
    sys.stdout.flush()

    nll = report_nll_loss / report_num_sents
    ppl = np.exp(nll * report_num_sents / report_num_words)

    return nll, ppl
def calc_au(model, test_data_batch, delta=0.01):
    """Count the model's active latent units.

    A unit is "active" when the variance, over the dataset, of its posterior
    mean exceeds ``delta``. Two passes over the data: first to compute the
    average posterior mean, then to accumulate squared deviations.

    Returns:
        (num_active, au_var) where au_var has shape (nz,).
    """
    # Pass 1: average posterior mean over all examples.
    n_examples = 0
    mean_acc = None
    for batch in test_data_batch:
        mu, _ = model.encode_stats(batch)
        contrib = mu.sum(dim=0, keepdim=True)
        mean_acc = contrib if mean_acc is None else mean_acc + contrib
        n_examples += mu.size(0)

    # (1, nz)
    avg_mean = mean_acc / n_examples

    # Pass 2: squared deviation of each mean from the average.
    n_examples = 0
    sq_acc = None
    for batch in test_data_batch:
        mu, _ = model.encode_stats(batch)
        contrib = ((mu - avg_mean) ** 2).sum(dim=0)
        sq_acc = contrib if sq_acc is None else sq_acc + contrib
        n_examples += mu.size(0)

    # unbiased variance per latent dimension, shape (nz,)
    au_var = sq_acc / (n_examples - 1)

    return (au_var >= delta).sum().item(), au_var
def sample_sentences(vae, vocab, device, num_sentences):
    """Sample sentences from the VAE prior and log the decoded text.

    NOTE(review): `logging` is declared global but never assigned in this
    module — calling this raises NameError unless a caller injects it.
    """
    global logging

    vae.eval()
    sampled_sents = []
    for i in range(num_sentences):
        z = vae.sample_from_prior(1)
        z = z.view(1, 1, -1)
        # NOTE(review): the empty-string keys below look like '<s>' / '</s>'
        # tokens whose angle brackets were stripped — confirm against vocab.
        start = vocab.word2id['']
        # START = torch.tensor([[[start]]])
        START = torch.tensor([[start]])
        end = vocab.word2id['']
        START = START.to(device)
        z = z.to(device)
        vae.eval()
        sentence = vae.decoder.sample_text(START, z, end, device)
        decoded_sentence = vocab.decode_sentence(sentence)
        sampled_sents.append(decoded_sentence)
    for i, sent in enumerate(sampled_sents):
        logging(i, ":", ' '.join(sent))


def visualize_latent(args, epoch, vae, device, test_data):
    """Dump one latent sample per test example to a tab-separated text file.

    Writes `<exp_dir>/synthetic_latent_<epoch>.txt` with lines of
    `label\\tval val ...`.
    """
    nsamples = 1

    with open(os.path.join(args.exp_dir, f'synthetic_latent_{epoch}.txt'), 'w') as f:
        test_data_batch, test_label_batch = test_data.create_data_batch_labels(batch_size=args.batch_size,
                                                                               device=device, batch_first=True)
        for i in range(len(test_data_batch)):
            batch_data = test_data_batch[i]
            batch_label = test_label_batch[i]
            batch_size, sent_len = batch_data.size()
            samples, _ = vae.encoder.encode(batch_data, nsamples)
            # NOTE: the inner loop variable shadows the outer `i`; harmless
            # because the outer for re-binds it, but confusing to read.
            for i in range(batch_size):
                for j in range(nsamples):
                    sample = samples[i, j, :].cpu().detach().numpy().tolist()
                    f.write(batch_label[i] + '\t' + ' '.join([str(val) for val in sample]) + '\n')
'#FF8C00', +'darkorchid': '#9932CC', +'darkred': '#8B0000', +'darksalmon': '#E9967A', +'darkseagreen': '#8FBC8F', +'darkslateblue': '#483D8B', +'darkslategray': '#2F4F4F', +'darkturquoise': '#00CED1', +'darkviolet': '#9400D3', +'deeppink': '#FF1493', +'deepskyblue': '#00BFFF', +'dimgray': '#696969', +'dodgerblue': '#1E90FF', +'firebrick': '#B22222', +'floralwhite': '#FFFAF0', +'forestgreen': '#228B22', +'fuchsia': '#FF00FF', +'gainsboro': '#DCDCDC', +'ghostwhite': '#F8F8FF', +'gold': '#FFD700', +'goldenrod': '#DAA520', +'gray': '#808080', +'green': '#008000', +'greenyellow': '#ADFF2F', +'honeydew': '#F0FFF0', +'hotpink': '#FF69B4', +'indianred': '#CD5C5C', +'indigo': '#4B0082', +'ivory': '#FFFFF0', +'khaki': '#F0E68C', +'lavender': '#E6E6FA', +'lavenderblush': '#FFF0F5', +'lawngreen': '#7CFC00', +'lemonchiffon': '#FFFACD', +'lightblue': '#ADD8E6', +'lightcoral': '#F08080', +'lightcyan': '#E0FFFF', +'lightgoldenrodyellow': '#FAFAD2', +'lightgreen': '#90EE90', +'lightgray': '#D3D3D3', +'lightpink': '#FFB6C1', +'lightsalmon': '#FFA07A', +'lightseagreen': '#20B2AA', +'lightskyblue': '#87CEFA', +'lightslategray': '#778899', +'lightsteelblue': '#B0C4DE', +'lightyellow': '#FFFFE0', +'lime': '#00FF00', +'limegreen': '#32CD32', +'linen': '#FAF0E6', +'magenta': '#FF00FF', +'maroon': '#800000', +'mediumaquamarine': '#66CDAA', +'mediumblue': '#0000CD', +'mediumorchid': '#BA55D3', +'mediumpurple': '#9370DB', +'mediumseagreen': '#3CB371', +'mediumslateblue': '#7B68EE', +'mediumspringgreen': '#00FA9A', +'mediumturquoise': '#48D1CC', +'mediumvioletred': '#C71585', +'midnightblue': '#191970', +'mintcream': '#F5FFFA', +'mistyrose': '#FFE4E1', +'moccasin': '#FFE4B5', +'navajowhite': '#FFDEAD', +'navy': '#000080', +'oldlace': '#FDF5E6', +'olive': '#808000', +'olivedrab': '#6B8E23', +'orange': '#FFA500', +'orangered': '#FF4500', +'orchid': '#DA70D6', +'palegoldenrod': '#EEE8AA', +'palegreen': '#98FB98', +'paleturquoise': '#AFEEEE', +'palevioletred': '#DB7093', +'papayawhip': '#FFEFD5', 
+'peachpuff': '#FFDAB9', +'peru': '#CD853F', +'pink': '#FFC0CB', +'plum': '#DDA0DD', +'powderblue': '#B0E0E6', +'purple': '#800080', +'red': '#FF0000', +'rosybrown': '#BC8F8F', +'royalblue': '#4169E1', +'saddlebrown': '#8B4513', +'salmon': '#FA8072', +'sandybrown': '#FAA460', +'seagreen': '#2E8B57', +'seashell': '#FFF5EE', +'sienna': '#A0522D', +'silver': '#C0C0C0', +'skyblue': '#87CEEB', +'slateblue': '#6A5ACD', +'slategray': '#708090', +'snow': '#FFFAFA', +'springgreen': '#00FF7F', +'steelblue': '#4682B4', +'tan': '#D2B48C', +'teal': '#008080', +'thistle': '#D8BFD8', +'tomato': '#FF6347', +'turquoise': '#40E0D0', +'violet': '#EE82EE', +'wheat': '#F5DEB3', +'white': '#FFFFFF', +'whitesmoke': '#F5F5F5', +'yellow': '#FFFF00', +'yellowgreen': '#9ACD32'} \ No newline at end of file