diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..c3f17c4
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitignore b/.gitignore
index b6e4761..5f8d199 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,8 @@ __pycache__/
*.so
# Distribution / packaging
+data/
+.idea/
.Python
build/
develop-eggs/
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..73f69e0
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/.idea/DU-VAE.iml b/.idea/DU-VAE.iml
new file mode 100644
index 0000000..8b8c395
--- /dev/null
+++ b/.idea/DU-VAE.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..cef7982
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,68 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..3999087
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..03505bc
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100755
index 0000000..d456b34
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Junxian He
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/config/.DS_Store b/config/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/config/.DS_Store differ
diff --git a/config/config_omniglot.py b/config/config_omniglot.py
new file mode 100755
index 0000000..6baf81f
--- /dev/null
+++ b/config/config_omniglot.py
@@ -0,0 +1,12 @@
+params={
+ 'img_size': [1,28,28],
+ 'nz': 32,
+ 'enc_layers': [64, 64, 64],
+ 'dec_kernel_size': [7, 7, 7, 7, 7, 5, 5, 5, 5, 3, 3, 3, 3],
+ 'dec_layers': [32,32,32,32,32,32,32,32,32,32,32,32],
+ 'latent_feature_map': 4,
+ 'batch_size': 50,
+ 'epochs': 500,
+ 'test_nepoch': 5,
+ 'data_file': 'data/omniglot_data/omniglot.pt'
+}
diff --git a/config/config_omniglot_ss.py b/config/config_omniglot_ss.py
new file mode 100644
index 0000000..cdc448e
--- /dev/null
+++ b/config/config_omniglot_ss.py
@@ -0,0 +1,7 @@
+params={
+ 'img_size': [1,28,28],
+ 'nz': 32,
+ 'latent_feature_map': 4,
+ 'test_nepoch': 5,
+ 'root': 'data/omniglot_data/'
+}
\ No newline at end of file
diff --git a/config/config_short_yelp.py b/config/config_short_yelp.py
new file mode 100755
index 0000000..236fbd6
--- /dev/null
+++ b/config/config_short_yelp.py
@@ -0,0 +1,155 @@
+params={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 128,
+ 'enc_nh': 512,
+ 'dec_nh': 512,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ 'batch_size': 32,
+ 'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/short_yelp_data/short_yelp.train.txt',
+ 'val_data': 'data/short_yelp_data/short_yelp.valid.txt',
+ 'test_data': 'data/short_yelp_data/short_yelp.test.txt',
+ 'vocab_file': 'data/short_yelp_data/vocab.txt',
+ "label": True
+}
+
+
+params_ss_10={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 128,
+ 'enc_nh': 512,
+ 'dec_nh': 512,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/short_yelp_data/short_yelp.train.10.txt',
+ 'val_data': 'data/short_yelp_data/short_yelp.valid.txt',
+ 'test_data': 'data/short_yelp_data/short_yelp.test.txt',
+ # 'vocab_file': 'data/short_yelp_data/vocab.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 10,
+ "label": True
+}
+
+
+params_ss_100={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 128,
+ 'enc_nh': 512,
+ 'dec_nh': 512,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/short_yelp_data/short_yelp.train.100.txt',
+ 'val_data': 'data/short_yelp_data/short_yelp.valid.txt',
+ 'test_data': 'data/short_yelp_data/short_yelp.test.txt',
+ # 'vocab_file': 'data/short_yelp_data/vocab.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 10,
+ "label": True
+}
+
+params_ss_500={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 128,
+ 'enc_nh': 512,
+ 'dec_nh': 512,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ #'batch_size': 50,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/short_yelp_data/short_yelp.train.500.txt',
+ 'val_data': 'data/short_yelp_data/short_yelp.valid.txt',
+ 'test_data': 'data/short_yelp_data/short_yelp.test.txt',
+ # 'vocab_file': 'data/short_yelp_data/vocab.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 10,
+ "label": True
+}
+
+params_ss_1000={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 128,
+ 'enc_nh': 512,
+ 'dec_nh': 512,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/short_yelp_data/short_yelp.train.1000.txt',
+ 'val_data': 'data/short_yelp_data/short_yelp.valid.txt',
+ 'test_data': 'data/short_yelp_data/short_yelp.test.txt',
+ # 'vocab_file': 'data/short_yelp_data/vocab.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 10,
+ "label": True
+}
+
+
+params_ss_2000={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 128,
+ 'enc_nh': 512,
+ 'dec_nh': 512,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/short_yelp_data/short_yelp.train.2000.txt',
+ 'val_data': 'data/short_yelp_data/short_yelp.valid.txt',
+ 'test_data': 'data/short_yelp_data/short_yelp.test.txt',
+ # 'vocab_file': 'data/short_yelp_data/vocab.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 10,
+ "label": True
+}
+
+
+params_ss_10000={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 128,
+ 'enc_nh': 512,
+ 'dec_nh': 512,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/short_yelp_data/short_yelp.train.10000.txt',
+ 'val_data': 'data/short_yelp_data/short_yelp.valid.txt',
+ 'test_data': 'data/short_yelp_data/short_yelp.test.txt',
+ # 'vocab_file': 'data/short_yelp_data/vocab.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 10,
+ "label": True
+}
diff --git a/config/config_synthetic.py b/config/config_synthetic.py
new file mode 100755
index 0000000..31649bf
--- /dev/null
+++ b/config/config_synthetic.py
@@ -0,0 +1,19 @@
+
+params={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 2,
+ 'ni': 50,
+ 'enc_nh': 50,
+ 'dec_nh': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ 'epochs': 50,
+ 'batch_size': 32,
+ 'test_nepoch': 1,
+ 'train_data': 'data/synthetic_data/synthetic_train.txt',
+ 'val_data': 'data/synthetic_data/synthetic_test.txt',
+ 'test_data': 'data/synthetic_data/synthetic_test.txt',
+ 'vocab_file': 'data/synthetic_data/vocab.txt',
+ "label": True
+}
diff --git a/config/config_yahoo.py b/config/config_yahoo.py
new file mode 100755
index 0000000..48c8f0b
--- /dev/null
+++ b/config/config_yahoo.py
@@ -0,0 +1,18 @@
+
+params={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ 'batch_size': 32,
+ 'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yahoo_data/yahoo.train.txt',
+ 'val_data': 'data/yahoo_data/yahoo.valid.txt',
+ 'test_data': 'data/yahoo_data/yahoo.test.txt',
+ 'vocab_file': 'data/yahoo_data/vocab.txt'
+}
diff --git a/config/config_yelp.py b/config/config_yelp.py
new file mode 100755
index 0000000..f9d0d64
--- /dev/null
+++ b/config/config_yelp.py
@@ -0,0 +1,149 @@
+
+params={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ 'batch_size': 32,
+ 'epochs': 150,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yelp_data/yelp.train.txt',
+ 'val_data': 'data/yelp_data/yelp.valid.txt',
+ 'test_data': 'data/yelp_data/yelp.test.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'label':True
+}
+
+
+params_ss_10={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yelp_data/yelp.train.10.txt',
+ 'val_data': 'data/yelp_data/yelp.valid.txt',
+ 'test_data': 'data/yelp_data/yelp.test.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 5,
+ "label": True
+}
+
+
+params_ss_100={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yelp_data/yelp.train.100.txt',
+ 'val_data': 'data/yelp_data/yelp.valid.txt',
+ 'test_data': 'data/yelp_data/yelp.test.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 5,
+ "label": True
+}
+
+params_ss_500={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ #'batch_size': 50,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yelp_data/yelp.train.500.txt',
+ 'val_data': 'data/yelp_data/yelp.valid.txt',
+ 'test_data': 'data/yelp_data/yelp.test.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 5,
+ "label": True
+}
+
+params_ss_1000={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yelp_data/yelp.train.1000.txt',
+ 'val_data': 'data/yelp_data/yelp.valid.txt',
+ 'test_data': 'data/yelp_data/yelp.test.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 5,
+ "label": True
+}
+
+
+params_ss_2000={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yelp_data/yelp.train.2000.txt',
+ 'val_data': 'data/yelp_data/yelp.valid.txt',
+ 'test_data': 'data/yelp_data/yelp.test.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 5,
+ "label": True
+}
+
+
+params_ss_10000={
+ 'enc_type': 'lstm',
+ 'dec_type': 'lstm',
+ 'nz': 32,
+ 'ni': 512,
+ 'enc_nh': 1024,
+ 'dec_nh': 1024,
+ 'log_niter': 50,
+ 'dec_dropout_in': 0.5,
+ 'dec_dropout_out': 0.5,
+ # 'batch_size': 32,
+ #'epochs': 100,
+ 'test_nepoch': 5,
+ 'train_data': 'data/yelp_data/yelp.train.10000.txt',
+ 'val_data': 'data/yelp_data/yelp.valid.txt',
+ 'test_data': 'data/yelp_data/yelp.test.txt',
+ 'vocab_file': 'data/yelp_data/vocab.txt',
+ 'ncluster': 5,
+ "label": True
+}
\ No newline at end of file
diff --git a/exp_utils.py b/exp_utils.py
new file mode 100755
index 0000000..6f55642
--- /dev/null
+++ b/exp_utils.py
@@ -0,0 +1,43 @@
+import functools
+import os, shutil
+
+import numpy as np
+
+import torch
+
+
+def logging(s, log_path, print_=True, log_=True):
+ print(s)
+ # if print_:
+ # print(s)
+ if log_:
+ with open(log_path, 'a+') as f_log:
+ f_log.write(s + '\n')
+
+def get_logger(log_path, **kwargs):
+ return functools.partial(logging, log_path=log_path, **kwargs)
+
+def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
+ if debug:
+ print('Debug Mode : no experiment dir created')
+ return functools.partial(logging, log_path=None, log_=False)
+
+ if os.path.exists(dir_path):
+ print("Path {} exists. Remove and remake.".format(dir_path))
+ shutil.rmtree(dir_path)
+
+ os.makedirs(dir_path)
+
+ print('Experiment dir : {}'.format(dir_path))
+ if scripts_to_save is not None:
+ script_path = os.path.join(dir_path, 'scripts')
+ if not os.path.exists(script_path):
+ os.makedirs(script_path)
+ for script in scripts_to_save:
+ dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
+ shutil.copyfile(script, dst_file)
+
+ return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
+
+def save_checkpoint(model, optimizer, path, epoch):
+ torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
\ No newline at end of file
diff --git a/image.py b/image.py
new file mode 100755
index 0000000..0b20da2
--- /dev/null
+++ b/image.py
@@ -0,0 +1,406 @@
+import sys
+import os
+import time
+import importlib
+import argparse
+
+import numpy as np
+
+import torch
+import torch.utils.data
+from torch import optim
+
+from modules import ResNetEncoderV2, BNResNetEncoderV2, PixelCNNDecoderV2
+from modules import VAE
+from logger import Logger
+from utils import calc_mi
+
+clip_grad = 5.0
+decay_epoch = 20
+lr_decay = 0.5
+max_decay = 5
+
+
+def init_config():
+ parser = argparse.ArgumentParser(description='VAE mode collapse study')
+
+ # model hyperparameters
+ parser.add_argument('--dataset', default='omniglot', type=str, help='dataset to use')
+
+ # optimization parameters
+ parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
+ parser.add_argument('--iw_nsamples', type=int, default=500,
+ help='number of samples to compute importance weighted estimate')
+ # select mode
+ parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
+ parser.add_argument('--load_path', type=str, default='')
+    # annealing parameters
+ parser.add_argument('--warm_up', type=int, default=10)
+ parser.add_argument('--kl_start', type=float, default=1.0)
+
+ # these are for slurm purpose to save model
+ parser.add_argument('--jobid', type=int, default=0, help='slurm job id')
+ parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
+ parser.add_argument('--device', type=str, default="cpu")
+ parser.add_argument('--delta_rate', type=float, default=1.0,
+                        help=" control the minimization of the variation of latent variables")
+ parser.add_argument('--gamma', type=float, default=0.5) # BN-VAE
+ parser.add_argument("--reset_dec", action="store_true", default=False)
+ parser.add_argument("--nz_new", type=int, default=32) # myGaussianLSTMencoder
+ parser.add_argument('--p_drop', type=float, default=0.2) # p \in [0, 1]
+
+ args = parser.parse_args()
+ if 'cuda' in args.device:
+ args.cuda = True
+ else:
+ args.cuda = False
+
+ load_str = "_load" if args.load_path != "" else ""
+ save_dir = "models/%s%s/" % (args.dataset, load_str)
+
+
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ cw_str = '_warm%d' % args.warm_up
+ else:
+ cw_str = ''
+
+ hkl_str = 'KL%.2f' % args.kl_start
+ drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else ''
+
+ seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
+ args.seed = seed_set[args.taskid]
+
+ if args.gamma > 0:
+ gamma_str = '_gamma%.2f' % (args.gamma)
+ else:
+ gamma_str = ''
+
+ id_ = "%s_%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d" % \
+ (args.dataset, hkl_str,
+ cw_str, load_str, gamma_str, args.delta_rate,
+ args.nz_new,drop_str,
+ args.jobid, args.taskid, args.seed)
+
+ save_dir += id_
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ save_path = os.path.join(save_dir, 'model.pt')
+
+ args.save_path = save_path
+ print("save path", args.save_path)
+
+ args.log_path = os.path.join(save_dir, "log.txt")
+ print("log path", args.log_path)
+
+ # load config file into args
+ config_file = "config.config_%s" % args.dataset
+ params = importlib.import_module(config_file).params
+ args = argparse.Namespace(**vars(args), **params)
+ if args.nz != args.nz_new:
+ args.nz = args.nz_new
+ print('args.nz', args.nz)
+
+ if 'label' in params:
+ args.label = params['label']
+ else:
+ args.label = False
+
+ args.kl_weight = 1
+
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if args.cuda:
+ torch.cuda.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+ return args
+
+
+def test(model, test_loader, mode, args):
+ report_kl_loss = report_kl_t_loss = report_rec_loss = 0
+ report_num_examples = 0
+ mutual_info = []
+ for datum in test_loader:
+ batch_data, _ = datum
+ batch_data = batch_data.to(args.device)
+
+ batch_size = batch_data.size(0)
+
+ report_num_examples += batch_size
+
+ loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False)
+ loss_kl_t = model.KL(batch_data, args)
+
+ assert (not loss_rc.requires_grad)
+
+ loss_rc = loss_rc.sum()
+ loss_kl = loss_kl.sum()
+ loss_kl_t = loss_kl_t.sum()
+
+ report_rec_loss += loss_rc.item()
+ report_kl_loss += loss_kl.item()
+ report_kl_t_loss += loss_kl_t.item()
+
+ mutual_info = calc_mi(model, test_loader, device=args.device)
+
+ test_loss = (report_rec_loss + report_kl_loss) / report_num_examples
+
+ nll = (report_kl_t_loss + report_rec_loss) / report_num_examples
+ kl = report_kl_loss / report_num_examples
+ kl_t = report_kl_t_loss / report_num_examples
+
+ print('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f' % \
+ (mode, test_loss, report_kl_t_loss / report_num_examples, mutual_info,
+ report_rec_loss / report_num_examples, nll))
+ sys.stdout.flush()
+
+    return test_loss, nll, kl_t  # returns the true kl_t, not the KL used during training
+
+
+def calc_au(model, test_loader, delta=0.01):
+ """compute the number of active units
+ """
+ means = []
+ for datum in test_loader:
+ batch_data, _ = datum
+
+ batch_data = batch_data.to(args.device)
+
+ mean, _ = model.encode_stats(batch_data)
+ means.append(mean)
+
+ means = torch.cat(means, dim=0)
+ au_mean = means.mean(0, keepdim=True)
+
+ # (batch_size, nz)
+ au_var = means - au_mean
+ ns = au_var.size(0)
+
+ au_var = (au_var ** 2).sum(dim=0) / (ns - 1)
+
+ return (au_var >= delta).sum().item(), au_var
+
+
+def calc_iwnll(model, test_loader, args):
+ report_nll_loss = 0
+ report_num_examples = 0
+ for id_, datum in enumerate(test_loader):
+ batch_data, _ = datum
+ batch_data = batch_data.to(args.device)
+
+ batch_size = batch_data.size(0)
+
+ report_num_examples += batch_size
+
+ if id_ % (round(len(test_loader) / 10)) == 0:
+ print('iw nll computing %d0%%' % (id_ / (round(len(test_loader) / 10))))
+ sys.stdout.flush()
+
+ loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples)
+
+ report_nll_loss += loss.sum().item()
+
+ nll = report_nll_loss / report_num_examples
+
+ print('iw nll: %.4f' % nll)
+ sys.stdout.flush()
+ return nll
+
+
+def main(args):
+ if args.cuda:
+ print('using cuda')
+ print(args)
+
+ args.device = torch.device(args.device)
+ device = args.device
+
+ opt_dict = {"not_improved": 0, "lr": 0.001, "best_loss": 1e4}
+
+ all_data = torch.load(args.data_file)
+ x_train, x_val, x_test = all_data
+ if args.dataset == 'omniglot':
+
+ x_train = x_train.to(device)
+ x_val = x_val.to(device)
+ x_test = x_test.to(device)
+ y_size = 1
+ y_train = x_train.new_zeros(x_train.size(0), y_size)
+ y_val = x_train.new_zeros(x_val.size(0), y_size)
+ y_test = x_train.new_zeros(x_test.size(0), y_size)
+
+ print(torch.__version__)
+ train_data = torch.utils.data.TensorDataset(x_train, y_train)
+ val_data = torch.utils.data.TensorDataset(x_val, y_val)
+ test_data = torch.utils.data.TensorDataset(x_test, y_test)
+
+
+ train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
+ val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=True)
+ test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
+ print('Train data: %d batches' % len(train_loader))
+ print('Val data: %d batches' % len(val_loader))
+ print('Test data: %d batches' % len(test_loader))
+ sys.stdout.flush()
+
+ log_niter = len(train_loader) // 5
+
+ if args.gamma > 0:
+ encoder = BNResNetEncoderV2(args)
+ else:
+ encoder = ResNetEncoderV2(args)
+
+ decoder = PixelCNNDecoderV2(args)
+
+ vae = VAE(encoder, decoder, args).to(device)
+
+ if args.eval:
+ print('begin evaluation')
+ test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True)
+ vae.load_state_dict(torch.load(args.load_path))
+ vae.eval()
+ with torch.no_grad():
+ test(vae, test_loader, "TEST", args)
+ au, au_var = calc_au(vae, test_loader)
+ print("%d active units" % au)
+ # print(au_var)
+ calc_iwnll(vae, test_loader, args)
+ return
+
+ enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
+ dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
+ opt_dict['lr'] = 0.001
+
+ iter_ = 0
+ best_loss = 1e4
+ decay_cnt = 0
+ vae.train()
+ start = time.time()
+
+ kl_weight = args.kl_start
+ anneal_rate = (1.0 - args.kl_start) / (args.warm_up * len(train_loader))
+
+ for epoch in range(args.epochs):
+
+ report_kl_loss = report_rec_loss = 0
+ report_num_examples = 0
+ for datum in train_loader:
+ batch_data, _ = datum
+ batch_data = batch_data.to(device)
+ if args.dataset != 'fashion-mnist':
+ batch_data = torch.bernoulli(batch_data)
+ batch_size = batch_data.size(0)
+
+ report_num_examples += batch_size
+
+ # kl_weight = 1.0
+
+ kl_weight = min(1.0, kl_weight + anneal_rate)
+ args.kl_weight = kl_weight
+
+ enc_optimizer.zero_grad()
+ dec_optimizer.zero_grad()
+
+ loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args)
+
+ loss = loss.mean(dim=-1)
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)
+
+ loss_rc = loss_rc.sum()
+ loss_kl = loss_kl.sum()
+
+ enc_optimizer.step()
+ dec_optimizer.step()
+
+ report_rec_loss += loss_rc.item()
+ report_kl_loss += loss_kl.item()
+
+ if iter_ % log_niter == 0:
+
+ train_loss = (report_rec_loss + report_kl_loss) / report_num_examples
+ if epoch == 0:
+ vae.eval()
+ with torch.no_grad():
+ mi = calc_mi(vae, val_loader, device=device)
+ au, _ = calc_au(vae, val_loader)
+
+ vae.train()
+
+ print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
+ 'au %d, time elapsed %.2fs' %
+ (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi,
+ report_rec_loss / report_num_examples, au, time.time() - start))
+ else:
+ print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
+ 'time elapsed %.2fs' %
+ (epoch, iter_, train_loss, report_kl_loss / report_num_examples,
+ report_rec_loss / report_num_examples, time.time() - start))
+ sys.stdout.flush()
+
+ report_rec_loss = report_kl_loss = 0
+ report_num_examples = 0
+
+ iter_ += 1
+
+ print('kl weight %.4f' % args.kl_weight)
+ print('epoch: %d, VAL' % epoch)
+
+ vae.eval()
+
+ with torch.no_grad():
+ loss, nll, kl = test(vae, val_loader, "VAL", args)
+ au, au_var = calc_au(vae, val_loader)
+ print("%d active units" % au)
+ # print(au_var)
+
+ if loss < best_loss:
+ print('update best loss')
+ best_loss = loss
+ torch.save(vae.state_dict(), args.save_path)
+
+ if loss > best_loss:
+ opt_dict["not_improved"] += 1
+ if opt_dict["not_improved"] >= decay_epoch:
+ opt_dict["best_loss"] = loss
+ opt_dict["not_improved"] = 0
+ opt_dict["lr"] = opt_dict["lr"] * lr_decay
+ vae.load_state_dict(torch.load(args.save_path))
+ decay_cnt += 1
+ print('new lr: %f' % opt_dict["lr"])
+ enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=opt_dict["lr"])
+ dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=opt_dict["lr"])
+ else:
+ opt_dict["not_improved"] = 0
+ opt_dict["best_loss"] = loss
+
+ if decay_cnt == max_decay:
+ break
+
+ if epoch % args.test_nepoch == 0:
+ with torch.no_grad():
+ loss, nll, kl = test(vae, test_loader, "TEST", args)
+
+ vae.train()
+
+ # compute importance weighted estimate of log p(x)
+ vae.load_state_dict(torch.load(args.save_path))
+ vae.eval()
+ with torch.no_grad():
+ loss, nll, kl = test(vae, test_loader, "TEST", args)
+ au, au_var = calc_au(vae, test_loader)
+ print("%d active units" % au)
+ # print(au_var)
+
+ test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True)
+
+ with torch.no_grad():
+ calc_iwnll(vae, test_loader, args)
+
+
+if __name__ == '__main__':
+ args = init_config()
+ if not args.eval:
+ sys.stdout = Logger(args.log_path)
+ main(args)
diff --git a/image_IAF.py b/image_IAF.py
new file mode 100644
index 0000000..ea5be8f
--- /dev/null
+++ b/image_IAF.py
@@ -0,0 +1,419 @@
+import sys
+import os
+import time
+import importlib
+import argparse
+
+import numpy as np
+
+import torch
+import torch.utils.data
+# from torchvision.utils import save_image
+from torch import nn, optim
+
+from modules import FlowResNetEncoderV2, PixelCNNDecoderV2
+from modules import VAEIAF as VAE
+from logger import Logger
+from utils import calc_mi
+
+clip_grad = 5.0
+decay_epoch = 20
+lr_decay = 0.5
+max_decay = 5
+
+
+def init_config():
+ parser = argparse.ArgumentParser(description='VAE mode collapse study')
+
+ # model hyperparameters
+ parser.add_argument('--dataset', default='omniglot', type=str, help='dataset to use')
+
+ # optimization parameters
+ parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
+ parser.add_argument('--iw_nsamples', type=int, default=500,
+ help='number of samples to compute importance weighted estimate')
+ # select mode
+ parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
+ parser.add_argument('--load_path', type=str, default='')
+    # annealing parameters
+ parser.add_argument('--warm_up', type=int, default=10)
+ parser.add_argument('--kl_start', type=float, default=1.0)
+ # these are for slurm purpose to save model
+ parser.add_argument('--jobid', type=int, default=0, help='slurm job id')
+ parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
+ parser.add_argument('--device', type=str, default="cpu")
+ parser.add_argument('--delta_rate', type=float, default=1.0,
+                        help=" control the minimization of the variation of latent variables")
+ parser.add_argument('--gamma', type=float, default=0.5) # BN-VAE
+ parser.add_argument("--reset_dec", action="store_true", default=False)
+ parser.add_argument("--nz_new", type=int, default=32) # myGaussianLSTMencoder
+ parser.add_argument('--p_drop', type=float, default=0.15) # p \in [0, 1]
+
+ parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow")
+ parser.add_argument('--flow_width', type=int, default=2, help="width of flow")
+
+ parser.add_argument("--fb", type=int, default=0,
+ help="0: no fb; 1: ")
+
+ parser.add_argument("--target_kl", type=float, default=0.0,
+ help="target kl of the free bits trick")
+ parser.add_argument('--drop_start', type=float, default=1.0, help="starting KL weight")
+
+ args = parser.parse_args()
+ if 'cuda' in args.device:
+ args.cuda = True
+ else:
+ args.cuda = False
+
+ load_str = "_load" if args.load_path != "" else ""
+ save_dir = "models/%s%s/" % (args.dataset, load_str)
+
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ cw_str = '_warm%d' % args.warm_up + '_%.2f' % args.kl_start
+ else:
+ cw_str = ''
+
+ if args.fb == 0:
+ fb_str = ""
+ elif args.fb in [1, 2]:
+ fb_str = "_fb%d_tr%.2f" % (args.fb, args.target_kl)
+
+ else:
+ fb_str = ''
+
+ drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else ''
+ if 1.0 > args.drop_start > 0 and args.p_drop != 0:
+ drop_str += '_start%.2f' % args.drop_start
+
+ seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
+ args.seed = seed_set[args.taskid]
+
+ if args.gamma > 0:
+ gamma_str = '_gamma%.2f' % (args.gamma)
+
+ else:
+ gamma_str = ''
+
+ if args.flow_depth > 0:
+ fd_str = '_fd%d_fw%d' % (args.flow_depth, args.flow_width)
+
+ id_ = "%s%s%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d_IAF" % \
+ (args.dataset, cw_str, load_str, gamma_str, fb_str, fd_str,
+ args.delta_rate, args.nz_new, drop_str,
+ args.jobid, args.taskid, args.seed)
+ save_dir += id_
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ save_path = os.path.join(save_dir, 'model.pt')
+
+ args.save_path = save_path
+ print("save path", args.save_path)
+
+ args.log_path = os.path.join(save_dir, "log.txt")
+ print("log path", args.log_path)
+
+ # load config file into args
+ config_file = "config.config_%s" % args.dataset
+ params = importlib.import_module(config_file).params
+ args = argparse.Namespace(**vars(args), **params)
+ if args.nz != args.nz_new:
+ args.nz = args.nz_new
+ print('args.nz', args.nz)
+
+ if 'label' in params:
+ args.label = params['label']
+ else:
+ args.label = False
+
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if args.cuda:
+ torch.cuda.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+ return args
+
+
+def test(model, test_loader, mode, args):
+ report_kl_loss = report_kl_t_loss = report_rec_loss = 0
+ report_num_examples = 0
+ mutual_info = []
+ for datum in test_loader:
+ batch_data, _ = datum
+ batch_data = batch_data.to(args.device)
+
+ batch_size = batch_data.size(0)
+
+ report_num_examples += batch_size
+
+ loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False)
+ loss_kl_t = model.KL(batch_data, args)
+
+ assert (not loss_rc.requires_grad)
+
+ loss_rc = loss_rc.sum()
+ loss_kl = loss_kl.sum()
+ loss_kl_t = loss_kl_t.sum()
+
+ report_rec_loss += loss_rc.item()
+ report_kl_loss += loss_kl.item()
+ report_kl_t_loss += loss_kl_t.item()
+
+ mutual_info = calc_mi(model, test_loader, device=args.device)
+
+ test_loss = (report_rec_loss + report_kl_loss) / report_num_examples
+
+ nll = (report_kl_t_loss + report_rec_loss) / report_num_examples
+ kl = report_kl_loss / report_num_examples
+ kl_t = report_kl_t_loss / report_num_examples
+
+ print('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f' % \
+ (mode, test_loss, report_kl_t_loss / report_num_examples, mutual_info,
+ report_rec_loss / report_num_examples, nll))
+ sys.stdout.flush()
+
+    return test_loss, nll, kl_t  # returns the true kl_t, not the KL used during training
+
+
def calc_iwnll(model, test_loader, args):
    """Importance-weighted NLL estimate, averaged per example.

    Uses ``args.iw_nsamples`` importance samples per data point and prints
    coarse progress roughly every 10% of the loader.
    """
    report_nll_loss = 0
    report_num_examples = 0
    # BUG FIX: round(len(loader) / 10) is 0 for loaders with fewer than 5
    # batches, which made the progress modulo below divide by zero.
    log_every = max(1, round(len(test_loader) / 10))
    for id_, datum in enumerate(test_loader):
        batch_data, _ = datum
        batch_data = batch_data.to(args.device)

        report_num_examples += batch_data.size(0)

        if id_ % log_every == 0:
            print('iw nll computing %d0%%' % (id_ / log_every))
            sys.stdout.flush()

        loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples)

        report_nll_loss += loss.sum().item()

    nll = report_nll_loss / report_num_examples

    print('iw nll: %.4f' % nll)
    sys.stdout.flush()
    return nll
+
def calc_au(model, test_loader, delta=0.01):
    """Compute the number of active units.

    A latent unit is "active" when the variance (over the data) of its
    posterior mean exceeds ``delta``.

    Returns:
        (num_active, au_var): the count and the per-unit variance tensor.
    """
    # BUG FIX: the original read the device from a global ``args`` that is
    # not a parameter of this function; derive it from the model instead so
    # the function is self-contained.
    device = next(model.parameters()).device
    means = []
    for datum in test_loader:
        batch_data, _ = datum
        batch_data = batch_data.to(device)

        mean, _ = model.encode_stats(batch_data)
        means.append(mean)

    means = torch.cat(means, dim=0)
    au_mean = means.mean(0, keepdim=True)

    # (num_examples, nz): centered posterior means, then the unbiased
    # variance of each unit
    centered = means - au_mean
    ns = centered.size(0)
    au_var = (centered ** 2).sum(dim=0) / (ns - 1)

    return (au_var >= delta).sum().item(), au_var
+
def main(args):
    """Train the VAE on the image dataset, or only evaluate it when
    ``--eval`` is set.

    Training uses KL annealing, per-epoch validation with
    checkpoint-reload + lr-decay on plateau, and finishes with an
    importance-weighted NLL estimate on the test set.
    """
    if args.cuda:
        print('using cuda')
    print(args)

    args.device = torch.device(args.device)
    device = args.device

    # lr-decay bookkeeping driven by the validation loss below
    opt_dict = {"not_improved": 0, "lr": 0.001, "best_loss": 1e4}

    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data
    if args.dataset == 'omniglot':
        x_train = x_train.to(device)
        x_val = x_val.to(device)
        x_test = x_test.to(device)
    # dummy labels: TensorDataset requires a label tensor even when unused
    y_size = 1
    y_train = x_train.new_zeros(x_train.size(0), y_size)
    y_val = x_train.new_zeros(x_val.size(0), y_size)
    y_test = x_train.new_zeros(x_test.size(0), y_size)

    print(torch.__version__)
    train_data = torch.utils.data.TensorDataset(x_train, y_train)
    val_data = torch.utils.data.TensorDataset(x_val, y_val)
    test_data = torch.utils.data.TensorDataset(x_test, y_test)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    sys.stdout.flush()

    # log roughly 5 times per epoch
    log_niter = len(train_loader) // 5

    encoder = FlowResNetEncoderV2(args)
    decoder = PixelCNNDecoderV2(args)

    vae = VAE(encoder, decoder, args).to(device)

    if args.eval:
        # evaluation-only path: load a checkpoint, report test metrics, exit
        print('begin evaluation')
        args.kl_weight = 1
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True)
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test(vae, test_loader, "TEST", args)
            au, au_var = calc_au(vae, test_loader)
            print("%d active units" % au)
            # print(au_var)
            calc_iwnll(vae, test_loader, args)
        return

    enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
    dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
    opt_dict['lr'] = 0.001

    iter_ = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    decay_cnt = pre_mi = best_mi = mi_not_improved = 0
    vae.train()
    start = time.time()

    # linear KL annealing from kl_start up to 1 over warm_up epochs
    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * len(train_loader))

    for epoch in range(args.epochs):

        report_kl_loss = report_rec_loss = 0
        report_num_examples = 0
        for datum in train_loader:
            batch_data, _ = datum
            batch_data = batch_data.to(device)
            # dynamic binarization: resample binary pixels every pass
            batch_data = torch.bernoulli(batch_data)
            batch_size = batch_data.size(0)
            report_num_examples += batch_size

            kl_weight = min(1.0, kl_weight + anneal_rate)
            args.kl_weight = kl_weight

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            enc_optimizer.step()
            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:

                train_loss = (report_rec_loss + report_kl_loss) / report_num_examples
                if epoch == 0:
                    # mi/au are expensive; only monitor them during epoch 0
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_loader, device=device)
                        au, _ = calc_au(vae, val_loader)

                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi,
                           report_rec_loss / report_num_examples, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples,
                           report_rec_loss / report_num_examples, time.time() - start))
                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_examples = 0

            iter_ += 1

        print('kl weight %.4f' % args.kl_weight)
        print('epoch: %d, VAL' % epoch)

        vae.eval()

        with torch.no_grad():
            loss, nll, kl = test(vae, val_loader, "VAL", args)
            au, au_var = calc_au(vae, val_loader)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            torch.save(vae.state_dict(), args.save_path)

        if loss > best_loss:
            # val loss got worse: after decay_epoch misses, reload the best
            # checkpoint and decay the learning rate
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                decay_cnt += 1
                print('new lr: %f' % opt_dict["lr"])
                enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=opt_dict["lr"])
                dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=opt_dict["lr"])
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl = test(vae, test_loader, "TEST", args)

        vae.train()

    # final report from the best checkpoint, then compute the importance
    # weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl = test(vae, test_loader, "TEST", args)
        au, au_var = calc_au(vae, test_loader)
        print("%d active units" % au)
        # print(au_var)

    test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=True)

    with torch.no_grad():
        calc_iwnll(vae, test_loader, args)
+
if __name__ == '__main__':
    args = init_config()
    # tee stdout into the run's log file while training; evaluation keeps
    # plain stdout so results are not appended to the training log
    if not args.eval:
        sys.stdout = Logger(args.log_path)
    main(args)
diff --git a/image_ss_omniglot.py b/image_ss_omniglot.py
new file mode 100644
index 0000000..509e02f
--- /dev/null
+++ b/image_ss_omniglot.py
@@ -0,0 +1,311 @@
+import os
+import time
+import importlib
+import argparse
+import sys
+import numpy as np
+
+import torch
+from torch import nn, optim
+
+from modules import ResNetEncoderV2,BNResNetEncoderV2, PixelCNNDecoderV2,FlowResNetEncoderV2
+from modules import VAE, LinearDiscriminator_only
+from logger import Logger
+from omniglotDataset import Omniglot
+
+
# Junxian's new parameters: optimization hyper-parameters shared below
clip_grad = 1.0    # max gradient norm for clipping
decay_epoch = 2    # epochs without val improvement before lr decay
lr_decay = 0.8     # multiplicative lr decay factor
max_decay = 5      # stop training after this many lr decays
+
+
def init_config():
    """Parse command-line arguments and finish run configuration.

    Sets seeds, merges the dataset-specific ``config.config_<dataset>_ss``
    params into the namespace, and derives ``exp_dir``/``log_path``.

    Returns:
        argparse.Namespace with all run settings.
    """
    parser = argparse.ArgumentParser(description='VAE mode collapse study')
    parser.add_argument('--gamma', type=float, default=0.0)
    parser.add_argument('--gamma_type', type=str, default='BN')
    parser.add_argument('--gamma_train', action="store_true", default=False)

    # model hyperparameters
    parser.add_argument('--delta', type=float, default=0.0)
    parser.add_argument('--dataset', default='omniglot', type=str, help='dataset to use')
    # optimization parameters
    parser.add_argument('--momentum', type=float, default=0.9, help='sgd momentum')
    parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="adam", help='sgd momentum')

    parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
    parser.add_argument('--iw_nsamples', type=int, default=500,
                        help='number of samples to compute importance weighted estimate')

    # select mode
    parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
    parser.add_argument('--load_path', type=str,
                        default='models/mnist/test/model.pt')  # TODO: set load_path

    # annealing paramters
    parser.add_argument('--warm_up', type=int, default=100,
                        help="number of annealing epochs. warm_up=0 means not anneal")
    parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight")

    # output directory
    parser.add_argument('--exp_dir', default=None, type=str,
                        help='experiment directory.')
    parser.add_argument("--save_ckpt", type=int, default=0,
                        help="save checkpoint every epoch before this number")
    parser.add_argument("--save_latent", type=int, default=0)

    # new
    parser.add_argument("--reset_dec", action="store_true", default=True)
    parser.add_argument("--load_best_epoch", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1.)

    parser.add_argument("--fb", type=int, default=0,
                        help="0: no fb; 1: fb; E")
    parser.add_argument("--target_kl", type=float, default=-1,
                        help="target kl of the free bits trick")

    parser.add_argument("--batch_size", type=int, default=50,
                        help="number of epochs")
    parser.add_argument("--epochs", type=int, default=300,
                        help="number of epochs")
    parser.add_argument("--num_label", type=int, default=100,
                        help="t")
    parser.add_argument("--freeze_enc", action="store_true",
                        default=True)  # True -> freeze the parameters of vae.encoder
    parser.add_argument("--discriminator", type=str, default="linear")

    parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
    parser.add_argument('--device', type=str, default="cpu")
    parser.add_argument('--delta_rate', type=float, default=0.0,
                        help=" coontrol the minization of the variation of latent variables")

    parser.add_argument("--nz_new", type=int, default=32)  # myGaussianLSTMencoder

    parser.add_argument('--IAF', action='store_true', default=False)
    parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow")
    parser.add_argument('--flow_width', type=int, default=60, help="width of flow")
    parser.add_argument('--p_drop', type=float, default=0)  # p in [0, 1]

    args = parser.parse_args()

    # set args.cuda
    if 'cuda' in args.device:
        args.cuda = True
    else:
        args.cuda = False

    # set seeds (indexed by slurm task id for reproducible sweeps)
    seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
    args.seed = seed_set[args.taskid]
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

    # merge dataset-specific parameters from the config module
    config_file = "config.config_%s_ss" % args.dataset
    params = importlib.import_module(config_file).params
    args = argparse.Namespace(**vars(args), **params)

    load_str = "_load" if args.load_path != "" else ""
    opt_str = "_adam" if args.opt == "adam" else "_sgd"
    nlabel_str = "_nlabel{}".format(args.num_label)
    freeze_str = "_freeze" if args.freeze_enc else ""

    if len(args.load_path.split("/")) > 2:
        load_path_str = args.load_path.split("/")[2]
    else:
        load_path_str = args.load_path.split("/")[1]

    model_str = "_{}".format(args.discriminator)
    # set load and save paths
    # BUG FIX: identity comparison (is None) instead of == None
    if args.exp_dir is None:
        args.exp_dir = "models/exp_{}{}_ss_ft/{}{}{}{}{}".format(args.dataset,
                                                                 load_str, load_path_str, model_str, opt_str,
                                                                 nlabel_str, freeze_str)
    # exist_ok avoids a race if several slurm tasks share the directory
    os.makedirs(args.exp_dir, exist_ok=True)
    args.log_path = os.path.join(args.exp_dir, 'log.txt')

    # set args.label
    if 'label' in params:
        args.label = params['label']
    else:
        args.label = False

    args.kl_weight = 1

    return args
+
+
+def test(model, test_loader, mode, args, verbose=False):
+
+ report_correct = report_loss = 0
+ report_num_sents = 0
+
+ N=0
+ for datum in test_loader:
+ batch_data, batch_labels = datum
+ batch_data = batch_data.to(args.device)
+ batch_labels = batch_labels.to(args.device).squeeze()
+ #batch_data = torch.bernoulli(batch_data)
+ batch_size = batch_data.size(0)
+
+ # not predict start symbol
+ report_num_sents += batch_size
+ loss, correct = model.get_performance_with_feature(batch_data, batch_labels)
+
+ loss = loss.sum()
+
+ report_loss += loss.item()
+ report_correct += correct
+ N+=1
+
+ test_loss = report_loss / report_num_sents
+ acc = report_correct / report_num_sents
+
+ if verbose:
+ print('%s --- avg_loss: %.4f, acc: %.4f' % \
+ (mode, test_loss, acc))
+ # sys.stdout.flush()
+
+ return test_loss, acc
+
def train(args, dataset: Omniglot, task, device, trainum=10):
    """Fit a discriminator on one few-shot Omniglot task and evaluate it.

    Loads the task's train/test split from ``dataset``, trains a fresh
    discriminator for ``args.epochs`` epochs, and evaluates on the task's
    test split.

    Returns:
        (loss, acc, NC): test loss, test accuracy, number of classes.
    """
    x_train, l_train, x_test, l_test, NC = dataset.load_task(task, trainum)
    train_data = torch.utils.data.TensorDataset(x_train, l_train)
    test_data = torch.utils.data.TensorDataset(x_test, l_test)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
    print('Train data: %d samples' % len(x_train))

    if args.discriminator == "linear":
        discriminator = LinearDiscriminator_only(args, NC).to(device)
    else:
        # BUG FIX: the original silently fell through for any other value and
        # crashed later with an UnboundLocalError.
        raise ValueError("discriminator not supported: %s" % args.discriminator)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
    else:
        raise ValueError("optimizer not supported")

    discriminator.train()

    for epoch in range(args.epochs):
        for datum in train_loader:
            batch_data, batch_labels = datum
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device).squeeze()
            # skip degenerate batches (e.g. batch-norm needs >= 2 samples)
            if batch_data.size(0) < 2:
                continue

            optimizer.zero_grad()
            # (batch_size) per-example loss and number of correct predictions
            loss, _ = discriminator.get_performance_with_feature(batch_data, batch_labels)
            loss.sum().backward()
            optimizer.step()

    # final evaluation on the task's held-out split
    discriminator.eval()
    with torch.no_grad():
        loss, acc = test(discriminator, test_loader, "TEST", args, verbose=True)
    return loss, acc, NC
+
def main(args):
    """Few-shot evaluation driver.

    Loads a pretrained VAE checkpoint, wraps its encoder in an Omniglot
    task sampler, then for each training-set size trains/evaluates a
    discriminator on 50 tasks and reports mean and class-count-weighted
    mean accuracy.
    """
    if args.cuda:
        print('using cuda')
    print(str(args))

    device = args.device

    # pick the encoder variant matching how the checkpoint was trained
    if args.gamma > 0 and not args.IAF:
        encoder = BNResNetEncoderV2(args)
    elif not args.IAF:
        encoder = ResNetEncoderV2(args)
    elif args.IAF:
        encoder = FlowResNetEncoderV2(args)

    decoder = PixelCNNDecoderV2(args, mode='large')  # if args.HKL == 'H':

    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device))
        vae.load_state_dict(loaded_state_dict)
        print("%s loaded" % args.load_path)
    vae.eval()

    # dataset extracts features through the (frozen) VAE encoder
    dataset = Omniglot(args.root, encoder=vae.encoder, device=device, IAF=args.IAF)

    for tasknum in [5, 10, 15]:
        acc_sum = 0
        acc_wsum = 0
        N = 0
        acclist = []
        NClist = []
        for task in range(50):
            N += 1
            loss, acc, NC = train(args, dataset, task, device, tasknum)
            acc_sum += acc
            acc_wsum += NC * acc
            acclist.append(acc)
            NClist.append(NC)
        # plain and class-count-weighted mean accuracy over the 50 tasks
        acc_mean = acc_sum / N
        acc_wmean = acc_wsum / sum(NClist)

        print('train_num', tasknum, 'acc', acc_mean, acc_wmean)
        print(acclist)
        # plt.plot(range(50), acclist, label='trainnum%d' % tasknum)
        # plt.show()
        # plt.legend()
+
+
+
+
+
if __name__ == '__main__':
    args = init_config()
    # mirror all stdout into the experiment's log file
    sys.stdout = Logger(args.log_path)
    print('---------------')
    main(args)
diff --git a/logger.py b/logger.py
new file mode 100755
index 0000000..04cd066
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,17 @@
+"""
+Logger class files
+"""
+import sys
+
class Logger(object):
    """Tee for stdout: every write goes to both the real terminal and a
    log file, each flushed immediately."""

    def __init__(self, output_file):
        self.terminal = sys.stdout
        self.log = open(output_file, "w")

    def write(self, message):
        # fan the message out to both sinks, flushing each write
        for stream in (self.terminal, self.log):
            print(message, end="", file=stream, flush=True)

    def flush(self):
        self.terminal.flush()
        self.log.flush()
diff --git a/models/.DS_Store b/models/.DS_Store
new file mode 100644
index 0000000..8d70105
Binary files /dev/null and b/models/.DS_Store differ
diff --git a/models/omniglot/.DS_Store b/models/omniglot/.DS_Store
new file mode 100644
index 0000000..47f6cbb
Binary files /dev/null and b/models/omniglot/.DS_Store differ
diff --git a/models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt b/models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt
new file mode 100644
index 0000000..fee3ace
--- /dev/null
+++ b/models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt
@@ -0,0 +1,6 @@
+Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=0.0, device='cpu', enc_layers=[64, 64, 64], epochs=500, eval=False, gamma=0.0, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, kl_weight=1, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/omniglot/omniglot_KL1.00_dr0.00_nz32_0_0_783435/model.pt', seed=783435, taskid=0, test_nepoch=5, warm_up=10)
+1.5.0
+Train data: 447 batches
+Val data: 40 batches
+Test data: 162 batches
+epoch: 0, iter: 0, avg_loss: 537.0994, kl: 30.8722, mi: 0.3572, recon: 506.2273,au 32, time elapsed 13.60s
diff --git a/models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt b/models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt
new file mode 100644
index 0000000..382017f
--- /dev/null
+++ b/models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt
@@ -0,0 +1,6 @@
+Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=1.0, device='cpu', enc_layers=[64, 64, 64], epochs=500, eval=False, gamma=0.5, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, kl_weight=1, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0.2, reset_dec=False, save_path='models/omniglot/omniglot_KL1.00_gamma0.50_dr1.00_nz32_drop0.20_0_0_783435/model.pt', seed=783435, taskid=0, test_nepoch=5, warm_up=10)
+1.5.0
+Train data: 447 batches
+Val data: 40 batches
+Test data: 162 batches
+epoch: 0, iter: 0, avg_loss: 528.2644, kl: 21.9613, mi: 0.1018, recon: 506.3031,au 0, time elapsed 12.70s
diff --git a/models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt b/models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt
new file mode 100644
index 0000000..6760dac
--- /dev/null
+++ b/models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt
@@ -0,0 +1,12 @@
+Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=0.0, device='cpu', drop_start=1.0, enc_layers=[64, 64, 64], epochs=500, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.0, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/omniglot/omniglot_fb1_tr0.15_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/model.pt', seed=783435, target_kl=0.15, taskid=0, test_nepoch=5, warm_up=10)
+1.5.0
+Train data: 447 batches
+Val data: 40 batches
+Test data: 162 batches
+> [0;32m/Users/shendazhong/Desktop/AAAI21/code_reference/Du-VAE/modules/encoders/enc_flow.py[0m(64)[0;36mencode[0;34m()[0m
+[0;32m 63 [0;31m[0;34m[0m[0m
+[0m[0;32m---> 64 [0;31m [0;32mreturn[0m [0mz_T[0m[0;34m,[0m [0mkl[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;34m[[0m[0;36m1[0m[0;34m,[0m [0;36m2[0m[0;34m][0m[0;34m)[0m [0;31m# like KL[0m[0;34m[0m[0m
+[0m[0;32m 65 [0;31m[0;34m[0m[0m
+[0m
+ipdb> --KeyboardInterrupt--
+ipdb>
\ No newline at end of file
diff --git a/models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt b/models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt
new file mode 100644
index 0000000..bcb0054
--- /dev/null
+++ b/models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt
@@ -0,0 +1,6 @@
+Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=0.0, device='cpu', drop_start=1.0, enc_layers=[64, 64, 64], epochs=500, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.0, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/omniglot/omniglot_fd2_fw2_dr0.00_nz32_0_0_783435_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_nepoch=5, warm_up=10)
+1.5.0
+Train data: 447 batches
+Val data: 40 batches
+Test data: 162 batches
+epoch: 0, iter: 0, avg_loss: 571.8783, kl: 24.7818, mi: 1.0189, recon: 547.0965,au 32, time elapsed 12.62s
diff --git a/models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt b/models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt
new file mode 100644
index 0000000..87581dd
--- /dev/null
+++ b/models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt
@@ -0,0 +1,6 @@
+Namespace(batch_size=50, cuda=False, data_file='data/omniglot_data/omniglot.pt', dataset='omniglot', dec_kernel_size=[9, 9, 9, 7, 7, 7, 5, 5, 5, 3, 3, 3], dec_layers=[32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32], delta_rate=1.0, device='cpu', drop_start=1.0, enc_layers=[64, 64, 64], epochs=500, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.5, img_size=[1, 28, 28], iw_nsamples=500, jobid=0, kl_start=1.0, label=False, latent_feature_map=4, load_path='', log_path='models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/log.txt', nsamples=1, nz=32, nz_new=32, p_drop=0.15, reset_dec=False, save_path='models/omniglot/omniglot_gamma0.50_fd2_fw2_dr1.00_nz32_drop0.15_0_0_783435_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_nepoch=5, warm_up=10)
+1.5.0
+Train data: 447 batches
+Val data: 40 batches
+Test data: 162 batches
+epoch: 0, iter: 0, avg_loss: 564.9081, kl: 17.1125, mi: 0.4995, recon: 547.7956,au 0, time elapsed 11.56s
diff --git a/models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt b/models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
new file mode 100644
index 0000000..eca0b9d
--- /dev/null
+++ b/models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
@@ -0,0 +1,6 @@
+Namespace(batch_size=32, cuda=False, dataset='short_yelp', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=512, dec_type='lstm', delta_rate=1, device='cpu', enc_nh=512, enc_type='lstm', epochs=100, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=0.0, label=True, load_path='', log_niter=50, log_path='models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=128, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/short_yelp/short_yelp_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/short_yelp_data/short_yelp.test.txt', test_nepoch=5, train_data='data/short_yelp_data/short_yelp.train.txt', val_data='data/short_yelp_data/short_yelp.valid.txt', vocab_file='data/short_yelp_data/vocab.txt', warm_up=10)
+data/short_yelp_data/vocab.txt
+Train data: 100000 samples
+finish reading datasets, vocab size is 8411
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 151.0429, kl/H(z|x): 15.4808, mi: 0.0923, recon: 135.5621,au 0, time elapsed 40.76s
diff --git a/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt
new file mode 100644
index 0000000..68c778f
--- /dev/null
+++ b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt
@@ -0,0 +1,743 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=0.0, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=0.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 76.0298, kl/H(z|x): 0.0000, mi: 0.0056, recon: 76.0298,au 0, time elapsed 3.91s
+epoch: 0, iter: 50, avg_loss: 66.4376, kl/H(z|x): 0.0000, mi: 0.0607, recon: 66.4376,au 0, time elapsed 8.59s
+epoch: 0, iter: 100, avg_loss: 60.9033, kl/H(z|x): 0.0000, mi: -0.0787, recon: 60.9032,au 0, time elapsed 13.32s
+epoch: 0, iter: 150, avg_loss: 59.8395, kl/H(z|x): 0.0018, mi: 0.0027, recon: 59.8377,au 0, time elapsed 18.00s
+epoch: 0, iter: 200, avg_loss: 57.6219, kl/H(z|x): 0.0059, mi: 0.0679, recon: 57.6160,au 0, time elapsed 22.66s
+epoch: 0, iter: 250, avg_loss: 54.6628, kl/H(z|x): 0.0098, mi: 0.1069, recon: 54.6529,au 0, time elapsed 27.49s
+epoch: 0, iter: 300, avg_loss: 53.0137, kl/H(z|x): 0.0110, mi: 0.0419, recon: 53.0027,au 0, time elapsed 32.34s
+epoch: 0, iter: 350, avg_loss: 51.6142, kl/H(z|x): 0.0112, mi: 0.0623, recon: 51.6030,au 0, time elapsed 37.05s
+epoch: 0, iter: 400, avg_loss: 51.3852, kl/H(z|x): 0.0135, mi: -0.0299, recon: 51.3717,au 0, time elapsed 41.78s
+epoch: 0, iter: 450, avg_loss: 50.7162, kl/H(z|x): 0.0169, mi: 0.0734, recon: 50.6993,au 0, time elapsed 46.49s
+kl weight 1.0000
+VAL --- avg_loss: 46.5658, kl/H(z|x): 0.0082, mi: -0.0036, recon: 46.5576, nll: 46.5658, ppl: 68.9411
+0 active units
+update best loss
+TEST --- avg_loss: 46.5894, kl/H(z|x): 0.0082, mi: -0.0030, recon: 46.5813, nll: 46.5894, ppl: 69.0896
+epoch: 1, iter: 500, avg_loss: 46.1078, kl/H(z|x): 0.0081, recon: 46.0997,time elapsed 56.18s
+epoch: 1, iter: 550, avg_loss: 48.3830, kl/H(z|x): 0.0233, recon: 48.3596,time elapsed 57.28s
+epoch: 1, iter: 600, avg_loss: 48.6702, kl/H(z|x): 0.0146, recon: 48.6555,time elapsed 58.38s
+epoch: 1, iter: 650, avg_loss: 47.8828, kl/H(z|x): 0.0140, recon: 47.8688,time elapsed 59.44s
+epoch: 1, iter: 700, avg_loss: 47.5835, kl/H(z|x): 0.0077, recon: 47.5758,time elapsed 60.52s
+epoch: 1, iter: 750, avg_loss: 47.6507, kl/H(z|x): 0.0190, recon: 47.6317,time elapsed 61.58s
+epoch: 1, iter: 800, avg_loss: 47.6542, kl/H(z|x): 0.0161, recon: 47.6381,time elapsed 62.70s
+epoch: 1, iter: 850, avg_loss: 47.3931, kl/H(z|x): 0.0086, recon: 47.3845,time elapsed 63.78s
+epoch: 1, iter: 900, avg_loss: 47.7290, kl/H(z|x): 0.0088, recon: 47.7202,time elapsed 64.83s
+epoch: 1, iter: 950, avg_loss: 47.1321, kl/H(z|x): 0.0130, recon: 47.1190,time elapsed 65.92s
+kl weight 1.0000
+VAL --- avg_loss: 45.2807, kl/H(z|x): 0.0073, mi: -0.0388, recon: 45.2735, nll: 45.2807, ppl: 61.3398
+0 active units
+update best loss
+TEST --- avg_loss: 45.2893, kl/H(z|x): 0.0073, mi: 0.0676, recon: 45.2821, nll: 45.2893, ppl: 61.3879
+epoch: 2, iter: 1000, avg_loss: 45.6955, kl/H(z|x): 0.0072, recon: 45.6883,time elapsed 75.45s
+epoch: 2, iter: 1050, avg_loss: 46.7630, kl/H(z|x): 0.0097, recon: 46.7533,time elapsed 76.54s
+epoch: 2, iter: 1100, avg_loss: 47.1458, kl/H(z|x): 0.0081, recon: 47.1377,time elapsed 77.67s
+epoch: 2, iter: 1150, avg_loss: 47.0823, kl/H(z|x): 0.0073, recon: 47.0751,time elapsed 78.73s
+epoch: 2, iter: 1200, avg_loss: 46.7632, kl/H(z|x): 0.0055, recon: 46.7577,time elapsed 79.80s
+epoch: 2, iter: 1250, avg_loss: 47.3431, kl/H(z|x): 0.0089, recon: 47.3342,time elapsed 80.87s
+epoch: 2, iter: 1300, avg_loss: 46.7255, kl/H(z|x): 0.0058, recon: 46.7197,time elapsed 81.92s
+epoch: 2, iter: 1350, avg_loss: 46.7585, kl/H(z|x): 0.0034, recon: 46.7551,time elapsed 83.03s
+epoch: 2, iter: 1400, avg_loss: 45.8850, kl/H(z|x): 0.0032, recon: 45.8818,time elapsed 84.11s
+epoch: 2, iter: 1450, avg_loss: 46.3266, kl/H(z|x): 0.0028, recon: 46.3237,time elapsed 85.29s
+kl weight 1.0000
+VAL --- avg_loss: 44.1531, kl/H(z|x): 0.0026, mi: -0.1020, recon: 44.1505, nll: 44.1531, ppl: 55.3633
+0 active units
+update best loss
+TEST --- avg_loss: 44.1502, kl/H(z|x): 0.0026, mi: 0.0268, recon: 44.1476, nll: 44.1502, ppl: 55.3488
+epoch: 3, iter: 1500, avg_loss: 45.4690, kl/H(z|x): 0.0026, recon: 45.4664,time elapsed 94.90s
+epoch: 3, iter: 1550, avg_loss: 46.3678, kl/H(z|x): 0.0089, recon: 46.3588,time elapsed 95.99s
+epoch: 3, iter: 1600, avg_loss: 46.3185, kl/H(z|x): 0.0036, recon: 46.3148,time elapsed 97.05s
+epoch: 3, iter: 1650, avg_loss: 46.1619, kl/H(z|x): 0.0050, recon: 46.1569,time elapsed 98.13s
+epoch: 3, iter: 1700, avg_loss: 46.0672, kl/H(z|x): 0.0035, recon: 46.0636,time elapsed 99.21s
+epoch: 3, iter: 1750, avg_loss: 46.3371, kl/H(z|x): 0.0024, recon: 46.3347,time elapsed 100.29s
+epoch: 3, iter: 1800, avg_loss: 46.4199, kl/H(z|x): 0.0049, recon: 46.4150,time elapsed 101.36s
+epoch: 3, iter: 1850, avg_loss: 45.9181, kl/H(z|x): 0.0044, recon: 45.9137,time elapsed 102.49s
+epoch: 3, iter: 1900, avg_loss: 46.0571, kl/H(z|x): 0.0087, recon: 46.0484,time elapsed 103.54s
+epoch: 3, iter: 1950, avg_loss: 45.8462, kl/H(z|x): 0.0054, recon: 45.8408,time elapsed 104.61s
+kl weight 1.0000
+VAL --- avg_loss: 44.0262, kl/H(z|x): 0.0022, mi: -0.0017, recon: 44.0241, nll: 44.0262, ppl: 54.7285
+0 active units
+update best loss
+TEST --- avg_loss: 44.0316, kl/H(z|x): 0.0022, mi: -0.1088, recon: 44.0294, nll: 44.0316, ppl: 54.7551
+epoch: 4, iter: 2000, avg_loss: 43.8706, kl/H(z|x): 0.0022, recon: 43.8684,time elapsed 114.37s
+epoch: 4, iter: 2050, avg_loss: 45.8839, kl/H(z|x): 0.0051, recon: 45.8788,time elapsed 115.43s
+epoch: 4, iter: 2100, avg_loss: 46.1152, kl/H(z|x): 0.0088, recon: 46.1065,time elapsed 116.49s
+epoch: 4, iter: 2150, avg_loss: 46.1732, kl/H(z|x): 0.0055, recon: 46.1677,time elapsed 117.59s
+epoch: 4, iter: 2200, avg_loss: 46.1230, kl/H(z|x): 0.0056, recon: 46.1174,time elapsed 118.66s
+epoch: 4, iter: 2250, avg_loss: 45.9930, kl/H(z|x): 0.0097, recon: 45.9833,time elapsed 119.75s
+epoch: 4, iter: 2300, avg_loss: 46.5047, kl/H(z|x): 0.0126, recon: 46.4921,time elapsed 120.81s
+epoch: 4, iter: 2350, avg_loss: 46.0809, kl/H(z|x): 0.0108, recon: 46.0701,time elapsed 121.87s
+epoch: 4, iter: 2400, avg_loss: 46.5641, kl/H(z|x): 0.0058, recon: 46.5583,time elapsed 122.97s
+epoch: 4, iter: 2450, avg_loss: 46.0603, kl/H(z|x): 0.0050, recon: 46.0553,time elapsed 124.03s
+kl weight 1.0000
+VAL --- avg_loss: 44.3252, kl/H(z|x): 0.0032, mi: 0.0661, recon: 44.3220, nll: 44.3252, ppl: 56.2365
+0 active units
+TEST --- avg_loss: 44.3159, kl/H(z|x): 0.0032, mi: 0.0157, recon: 44.3127, nll: 44.3159, ppl: 56.1890
+epoch: 5, iter: 2500, avg_loss: 46.6240, kl/H(z|x): 0.0032, recon: 46.6208,time elapsed 133.62s
+epoch: 5, iter: 2550, avg_loss: 45.8931, kl/H(z|x): 0.0056, recon: 45.8875,time elapsed 134.71s
+epoch: 5, iter: 2600, avg_loss: 45.6455, kl/H(z|x): 0.0056, recon: 45.6399,time elapsed 135.79s
+epoch: 5, iter: 2650, avg_loss: 45.9743, kl/H(z|x): 0.0043, recon: 45.9701,time elapsed 136.84s
+epoch: 5, iter: 2700, avg_loss: 45.7565, kl/H(z|x): 0.0045, recon: 45.7520,time elapsed 137.95s
+epoch: 5, iter: 2750, avg_loss: 45.9835, kl/H(z|x): 0.0050, recon: 45.9785,time elapsed 139.02s
+epoch: 5, iter: 2800, avg_loss: 45.7868, kl/H(z|x): 0.0041, recon: 45.7827,time elapsed 140.11s
+epoch: 5, iter: 2850, avg_loss: 45.7049, kl/H(z|x): 0.0042, recon: 45.7007,time elapsed 141.18s
+epoch: 5, iter: 2900, avg_loss: 46.0195, kl/H(z|x): 0.0034, recon: 46.0161,time elapsed 142.27s
+epoch: 5, iter: 2950, avg_loss: 46.1072, kl/H(z|x): 0.0051, recon: 46.1022,time elapsed 143.38s
+kl weight 1.0000
+VAL --- avg_loss: 43.8109, kl/H(z|x): 0.0038, mi: -0.0220, recon: 43.8071, nll: 43.8109, ppl: 53.6676
+0 active units
+update best loss
+TEST --- avg_loss: 43.8015, kl/H(z|x): 0.0038, mi: -0.0332, recon: 43.7977, nll: 43.8015, ppl: 53.6218
+epoch: 6, iter: 3000, avg_loss: 43.1108, kl/H(z|x): 0.0038, recon: 43.1070,time elapsed 153.22s
+epoch: 6, iter: 3050, avg_loss: 45.8055, kl/H(z|x): 0.0040, recon: 45.8015,time elapsed 154.31s
+epoch: 6, iter: 3100, avg_loss: 45.9045, kl/H(z|x): 0.0038, recon: 45.9007,time elapsed 155.38s
+epoch: 6, iter: 3150, avg_loss: 45.1809, kl/H(z|x): 0.0025, recon: 45.1784,time elapsed 156.45s
+epoch: 6, iter: 3200, avg_loss: 45.6232, kl/H(z|x): 0.0026, recon: 45.6206,time elapsed 157.51s
+epoch: 6, iter: 3250, avg_loss: 45.1418, kl/H(z|x): 0.0020, recon: 45.1398,time elapsed 158.63s
+epoch: 6, iter: 3300, avg_loss: 45.6125, kl/H(z|x): 0.0031, recon: 45.6094,time elapsed 159.70s
+epoch: 6, iter: 3350, avg_loss: 45.3761, kl/H(z|x): 0.0019, recon: 45.3742,time elapsed 160.77s
+epoch: 6, iter: 3400, avg_loss: 45.3928, kl/H(z|x): 0.0022, recon: 45.3906,time elapsed 161.88s
+epoch: 6, iter: 3450, avg_loss: 45.8120, kl/H(z|x): 0.0024, recon: 45.8096,time elapsed 162.97s
+kl weight 1.0000
+VAL --- avg_loss: 43.5564, kl/H(z|x): 0.0021, mi: 0.0195, recon: 43.5543, nll: 43.5564, ppl: 52.4400
+0 active units
+update best loss
+TEST --- avg_loss: 43.5776, kl/H(z|x): 0.0021, mi: 0.0936, recon: 43.5755, nll: 43.5776, ppl: 52.5412
+epoch: 7, iter: 3500, avg_loss: 44.6366, kl/H(z|x): 0.0021, recon: 44.6345,time elapsed 172.75s
+epoch: 7, iter: 3550, avg_loss: 45.0213, kl/H(z|x): 0.0021, recon: 45.0192,time elapsed 173.83s
+epoch: 7, iter: 3600, avg_loss: 45.6376, kl/H(z|x): 0.0029, recon: 45.6347,time elapsed 174.87s
+epoch: 7, iter: 3650, avg_loss: 45.7895, kl/H(z|x): 0.0036, recon: 45.7860,time elapsed 175.95s
+epoch: 7, iter: 3700, avg_loss: 45.5016, kl/H(z|x): 0.0038, recon: 45.4978,time elapsed 177.00s
+epoch: 7, iter: 3750, avg_loss: 45.6823, kl/H(z|x): 0.0032, recon: 45.6791,time elapsed 178.11s
+epoch: 7, iter: 3800, avg_loss: 45.2862, kl/H(z|x): 0.0034, recon: 45.2827,time elapsed 179.19s
+epoch: 7, iter: 3850, avg_loss: 45.3230, kl/H(z|x): 0.0029, recon: 45.3201,time elapsed 180.27s
+epoch: 7, iter: 3900, avg_loss: 45.1383, kl/H(z|x): 0.0025, recon: 45.1358,time elapsed 181.35s
+epoch: 7, iter: 3950, avg_loss: 45.4385, kl/H(z|x): 0.0027, recon: 45.4358,time elapsed 182.47s
+kl weight 1.0000
+VAL --- avg_loss: 43.6658, kl/H(z|x): 0.0022, mi: 0.0445, recon: 43.6635, nll: 43.6658, ppl: 52.9641
+0 active units
+TEST --- avg_loss: 43.6665, kl/H(z|x): 0.0022, mi: 0.0049, recon: 43.6643, nll: 43.6665, ppl: 52.9676
+epoch: 8, iter: 4000, avg_loss: 46.5223, kl/H(z|x): 0.0022, recon: 46.5200,time elapsed 191.89s
+epoch: 8, iter: 4050, avg_loss: 44.8934, kl/H(z|x): 0.0035, recon: 44.8899,time elapsed 193.07s
+epoch: 8, iter: 4100, avg_loss: 45.4792, kl/H(z|x): 0.0068, recon: 45.4724,time elapsed 194.14s
+epoch: 8, iter: 4150, avg_loss: 45.5920, kl/H(z|x): 0.0026, recon: 45.5894,time elapsed 195.25s
+epoch: 8, iter: 4200, avg_loss: 45.3591, kl/H(z|x): 0.0034, recon: 45.3557,time elapsed 196.38s
+epoch: 8, iter: 4250, avg_loss: 45.7200, kl/H(z|x): 0.0038, recon: 45.7162,time elapsed 197.49s
+epoch: 8, iter: 4300, avg_loss: 45.4247, kl/H(z|x): 0.0022, recon: 45.4225,time elapsed 198.56s
+epoch: 8, iter: 4350, avg_loss: 45.4409, kl/H(z|x): 0.0042, recon: 45.4367,time elapsed 199.64s
+epoch: 8, iter: 4400, avg_loss: 45.1211, kl/H(z|x): 0.0025, recon: 45.1186,time elapsed 200.72s
+epoch: 8, iter: 4450, avg_loss: 45.2950, kl/H(z|x): 0.0024, recon: 45.2926,time elapsed 201.78s
+kl weight 1.0000
+VAL --- avg_loss: 43.5856, kl/H(z|x): 0.0023, mi: 0.1106, recon: 43.5833, nll: 43.5856, ppl: 52.5797
+0 active units
+TEST --- avg_loss: 43.6012, kl/H(z|x): 0.0023, mi: -0.0433, recon: 43.5989, nll: 43.6012, ppl: 52.6544
+epoch: 9, iter: 4500, avg_loss: 44.3652, kl/H(z|x): 0.0023, recon: 44.3629,time elapsed 211.64s
+epoch: 9, iter: 4550, avg_loss: 45.4353, kl/H(z|x): 0.0063, recon: 45.4290,time elapsed 212.78s
+epoch: 9, iter: 4600, avg_loss: 45.5352, kl/H(z|x): 0.0076, recon: 45.5276,time elapsed 213.88s
+epoch: 9, iter: 4650, avg_loss: 44.9239, kl/H(z|x): 0.0022, recon: 44.9218,time elapsed 214.96s
+epoch: 9, iter: 4700, avg_loss: 44.9674, kl/H(z|x): 0.0025, recon: 44.9649,time elapsed 216.06s
+epoch: 9, iter: 4750, avg_loss: 45.3179, kl/H(z|x): 0.0024, recon: 45.3154,time elapsed 217.15s
+epoch: 9, iter: 4800, avg_loss: 45.2692, kl/H(z|x): 0.0022, recon: 45.2670,time elapsed 218.25s
+epoch: 9, iter: 4850, avg_loss: 45.2401, kl/H(z|x): 0.0017, recon: 45.2384,time elapsed 219.32s
+epoch: 9, iter: 4900, avg_loss: 45.3045, kl/H(z|x): 0.0019, recon: 45.3026,time elapsed 220.39s
+epoch: 9, iter: 4950, avg_loss: 45.3250, kl/H(z|x): 0.0023, recon: 45.3228,time elapsed 221.44s
+kl weight 1.0000
+VAL --- avg_loss: 43.6818, kl/H(z|x): 0.0018, mi: -0.1053, recon: 43.6799, nll: 43.6818, ppl: 53.0412
+0 active units
+TEST --- avg_loss: 43.6995, kl/H(z|x): 0.0018, mi: 0.1338, recon: 43.6977, nll: 43.6995, ppl: 53.1268
+epoch: 10, iter: 5000, avg_loss: 44.2786, kl/H(z|x): 0.0019, recon: 44.2768,time elapsed 230.94s
+epoch: 10, iter: 5050, avg_loss: 45.4548, kl/H(z|x): 0.0018, recon: 45.4530,time elapsed 232.11s
+epoch: 10, iter: 5100, avg_loss: 45.1848, kl/H(z|x): 0.0019, recon: 45.1830,time elapsed 233.17s
+epoch: 10, iter: 5150, avg_loss: 44.9300, kl/H(z|x): 0.0021, recon: 44.9279,time elapsed 234.25s
+epoch: 10, iter: 5200, avg_loss: 45.4789, kl/H(z|x): 0.0024, recon: 45.4765,time elapsed 235.36s
+epoch: 10, iter: 5250, avg_loss: 45.0654, kl/H(z|x): 0.0024, recon: 45.0630,time elapsed 236.44s
+epoch: 10, iter: 5300, avg_loss: 45.7442, kl/H(z|x): 0.0025, recon: 45.7417,time elapsed 237.52s
+epoch: 10, iter: 5350, avg_loss: 45.4598, kl/H(z|x): 0.0032, recon: 45.4566,time elapsed 238.61s
+epoch: 10, iter: 5400, avg_loss: 45.2135, kl/H(z|x): 0.0019, recon: 45.2116,time elapsed 239.67s
+epoch: 10, iter: 5450, avg_loss: 44.7176, kl/H(z|x): 0.0018, recon: 44.7159,time elapsed 240.73s
+kl weight 1.0000
+VAL --- avg_loss: 43.5995, kl/H(z|x): 0.0014, mi: 0.0867, recon: 43.5981, nll: 43.5995, ppl: 52.6458
+0 active units
+TEST --- avg_loss: 43.5949, kl/H(z|x): 0.0014, mi: -0.0502, recon: 43.5935, nll: 43.5949, ppl: 52.6239
+epoch: 11, iter: 5500, avg_loss: 43.6854, kl/H(z|x): 0.0014, recon: 43.6840,time elapsed 250.17s
+epoch: 11, iter: 5550, avg_loss: 44.8833, kl/H(z|x): 0.0018, recon: 44.8815,time elapsed 251.27s
+epoch: 11, iter: 5600, avg_loss: 45.3650, kl/H(z|x): 0.0018, recon: 45.3632,time elapsed 252.37s
+epoch: 11, iter: 5650, avg_loss: 45.1862, kl/H(z|x): 0.0023, recon: 45.1839,time elapsed 253.52s
+epoch: 11, iter: 5700, avg_loss: 45.1249, kl/H(z|x): 0.0020, recon: 45.1229,time elapsed 254.60s
+epoch: 11, iter: 5750, avg_loss: 45.1719, kl/H(z|x): 0.0020, recon: 45.1699,time elapsed 255.67s
+epoch: 11, iter: 5800, avg_loss: 44.8795, kl/H(z|x): 0.0018, recon: 44.8778,time elapsed 256.74s
+epoch: 11, iter: 5850, avg_loss: 44.8756, kl/H(z|x): 0.0020, recon: 44.8735,time elapsed 257.85s
+epoch: 11, iter: 5900, avg_loss: 45.5389, kl/H(z|x): 0.0023, recon: 45.5366,time elapsed 258.92s
+epoch: 11, iter: 5950, avg_loss: 44.9864, kl/H(z|x): 0.0022, recon: 44.9843,time elapsed 260.01s
+kl weight 1.0000
+VAL --- avg_loss: 43.4362, kl/H(z|x): 0.0013, mi: -0.0682, recon: 43.4349, nll: 43.4362, ppl: 51.8702
+0 active units
+update best loss
+TEST --- avg_loss: 43.4197, kl/H(z|x): 0.0013, mi: 0.0301, recon: 43.4184, nll: 43.4197, ppl: 51.7926
+epoch: 12, iter: 6000, avg_loss: 42.4735, kl/H(z|x): 0.0013, recon: 42.4721,time elapsed 269.71s
+epoch: 12, iter: 6050, avg_loss: 44.9651, kl/H(z|x): 0.0019, recon: 44.9631,time elapsed 270.82s
+epoch: 12, iter: 6100, avg_loss: 45.2931, kl/H(z|x): 0.0020, recon: 45.2912,time elapsed 271.90s
+epoch: 12, iter: 6150, avg_loss: 45.0727, kl/H(z|x): 0.0021, recon: 45.0707,time elapsed 272.95s
+epoch: 12, iter: 6200, avg_loss: 44.5651, kl/H(z|x): 0.0017, recon: 44.5635,time elapsed 274.02s
+epoch: 12, iter: 6250, avg_loss: 44.6325, kl/H(z|x): 0.0017, recon: 44.6308,time elapsed 275.08s
+epoch: 12, iter: 6300, avg_loss: 45.3624, kl/H(z|x): 0.0018, recon: 45.3605,time elapsed 276.14s
+epoch: 12, iter: 6350, avg_loss: 45.1114, kl/H(z|x): 0.0024, recon: 45.1090,time elapsed 277.23s
+epoch: 12, iter: 6400, avg_loss: 45.0876, kl/H(z|x): 0.0025, recon: 45.0852,time elapsed 278.34s
+epoch: 12, iter: 6450, avg_loss: 45.3110, kl/H(z|x): 0.0023, recon: 45.3087,time elapsed 279.39s
+kl weight 1.0000
+VAL --- avg_loss: 43.3657, kl/H(z|x): 0.0018, mi: 0.1527, recon: 43.3640, nll: 43.3657, ppl: 51.5389
+0 active units
+update best loss
+TEST --- avg_loss: 43.3642, kl/H(z|x): 0.0018, mi: -0.1021, recon: 43.3624, nll: 43.3642, ppl: 51.5317
+epoch: 13, iter: 6500, avg_loss: 47.0057, kl/H(z|x): 0.0017, recon: 47.0040,time elapsed 288.82s
+epoch: 13, iter: 6550, avg_loss: 44.6513, kl/H(z|x): 0.0027, recon: 44.6486,time elapsed 289.90s
+epoch: 13, iter: 6600, avg_loss: 45.1560, kl/H(z|x): 0.0015, recon: 45.1545,time elapsed 291.00s
+epoch: 13, iter: 6650, avg_loss: 45.0566, kl/H(z|x): 0.0019, recon: 45.0547,time elapsed 292.29s
+epoch: 13, iter: 6700, avg_loss: 45.3155, kl/H(z|x): 0.0051, recon: 45.3104,time elapsed 293.44s
+epoch: 13, iter: 6750, avg_loss: 45.5072, kl/H(z|x): 0.0033, recon: 45.5039,time elapsed 294.55s
+epoch: 13, iter: 6800, avg_loss: 45.3792, kl/H(z|x): 0.0046, recon: 45.3747,time elapsed 295.67s
+epoch: 13, iter: 6850, avg_loss: 45.1679, kl/H(z|x): 0.0026, recon: 45.1653,time elapsed 296.80s
+epoch: 13, iter: 6900, avg_loss: 45.7968, kl/H(z|x): 0.0042, recon: 45.7926,time elapsed 297.90s
+epoch: 13, iter: 6950, avg_loss: 45.2510, kl/H(z|x): 0.0045, recon: 45.2464,time elapsed 299.03s
+kl weight 1.0000
+VAL --- avg_loss: 43.3716, kl/H(z|x): 0.0023, mi: -0.0546, recon: 43.3693, nll: 43.3716, ppl: 51.5668
+0 active units
+TEST --- avg_loss: 43.3622, kl/H(z|x): 0.0023, mi: -0.0436, recon: 43.3599, nll: 43.3622, ppl: 51.5224
+epoch: 14, iter: 7000, avg_loss: 41.5074, kl/H(z|x): 0.0023, recon: 41.5051,time elapsed 308.56s
+epoch: 14, iter: 7050, avg_loss: 45.4000, kl/H(z|x): 0.0030, recon: 45.3970,time elapsed 309.65s
+epoch: 14, iter: 7100, avg_loss: 44.7859, kl/H(z|x): 0.0028, recon: 44.7830,time elapsed 310.71s
+epoch: 14, iter: 7150, avg_loss: 45.1264, kl/H(z|x): 0.0026, recon: 45.1238,time elapsed 311.79s
+epoch: 14, iter: 7200, avg_loss: 45.0590, kl/H(z|x): 0.0023, recon: 45.0566,time elapsed 312.87s
+epoch: 14, iter: 7250, avg_loss: 45.1674, kl/H(z|x): 0.0028, recon: 45.1646,time elapsed 314.04s
+epoch: 14, iter: 7300, avg_loss: 45.4094, kl/H(z|x): 0.0038, recon: 45.4056,time elapsed 315.11s
+epoch: 14, iter: 7350, avg_loss: 44.7291, kl/H(z|x): 0.0041, recon: 44.7250,time elapsed 316.20s
+epoch: 14, iter: 7400, avg_loss: 45.2026, kl/H(z|x): 0.0023, recon: 45.2003,time elapsed 317.33s
+epoch: 14, iter: 7450, avg_loss: 45.1990, kl/H(z|x): 0.0022, recon: 45.1968,time elapsed 318.41s
+kl weight 1.0000
+VAL --- avg_loss: 43.3621, kl/H(z|x): 0.0028, mi: -0.0980, recon: 43.3593, nll: 43.3621, ppl: 51.5220
+0 active units
+update best loss
+TEST --- avg_loss: 43.3668, kl/H(z|x): 0.0028, mi: 0.0047, recon: 43.3640, nll: 43.3668, ppl: 51.5442
+epoch: 15, iter: 7500, avg_loss: 45.3504, kl/H(z|x): 0.0028, recon: 45.3475,time elapsed 328.06s
+epoch: 15, iter: 7550, avg_loss: 44.6513, kl/H(z|x): 0.0034, recon: 44.6480,time elapsed 329.16s
+epoch: 15, iter: 7600, avg_loss: 44.5779, kl/H(z|x): 0.0029, recon: 44.5750,time elapsed 330.27s
+epoch: 15, iter: 7650, avg_loss: 45.3756, kl/H(z|x): 0.0024, recon: 45.3732,time elapsed 331.43s
+epoch: 15, iter: 7700, avg_loss: 45.0796, kl/H(z|x): 0.0040, recon: 45.0756,time elapsed 332.54s
+epoch: 15, iter: 7750, avg_loss: 44.8463, kl/H(z|x): 0.0028, recon: 44.8435,time elapsed 333.66s
+epoch: 15, iter: 7800, avg_loss: 45.4175, kl/H(z|x): 0.0040, recon: 45.4135,time elapsed 334.77s
+epoch: 15, iter: 7850, avg_loss: 45.3280, kl/H(z|x): 0.0024, recon: 45.3256,time elapsed 335.88s
+epoch: 15, iter: 7900, avg_loss: 45.0953, kl/H(z|x): 0.0026, recon: 45.0926,time elapsed 336.97s
+epoch: 15, iter: 7950, avg_loss: 45.5351, kl/H(z|x): 0.0019, recon: 45.5332,time elapsed 338.10s
+kl weight 1.0000
+VAL --- avg_loss: 43.7256, kl/H(z|x): 0.0015, mi: -0.0466, recon: 43.7241, nll: 43.7256, ppl: 53.2529
+0 active units
+TEST --- avg_loss: 43.7168, kl/H(z|x): 0.0015, mi: 0.0051, recon: 43.7153, nll: 43.7168, ppl: 53.2103
+epoch: 16, iter: 8000, avg_loss: 48.4416, kl/H(z|x): 0.0015, recon: 48.4401,time elapsed 347.57s
+epoch: 16, iter: 8050, avg_loss: 44.7687, kl/H(z|x): 0.0016, recon: 44.7672,time elapsed 348.68s
+epoch: 16, iter: 8100, avg_loss: 44.4672, kl/H(z|x): 0.0016, recon: 44.4656,time elapsed 349.77s
+epoch: 16, iter: 8150, avg_loss: 45.1387, kl/H(z|x): 0.0014, recon: 45.1373,time elapsed 350.86s
+epoch: 16, iter: 8200, avg_loss: 45.0043, kl/H(z|x): 0.0031, recon: 45.0012,time elapsed 352.23s
+epoch: 16, iter: 8250, avg_loss: 45.2002, kl/H(z|x): 0.0018, recon: 45.1984,time elapsed 353.31s
+epoch: 16, iter: 8300, avg_loss: 45.3384, kl/H(z|x): 0.0020, recon: 45.3364,time elapsed 354.36s
+epoch: 16, iter: 8350, avg_loss: 45.0697, kl/H(z|x): 0.0023, recon: 45.0674,time elapsed 355.48s
+epoch: 16, iter: 8400, avg_loss: 45.1004, kl/H(z|x): 0.0015, recon: 45.0988,time elapsed 356.56s
+epoch: 16, iter: 8450, avg_loss: 44.7239, kl/H(z|x): 0.0019, recon: 44.7219,time elapsed 357.69s
+kl weight 1.0000
+VAL --- avg_loss: 43.3528, kl/H(z|x): 0.0016, mi: -0.0783, recon: 43.3512, nll: 43.3528, ppl: 51.4786
+0 active units
+update best loss
+TEST --- avg_loss: 43.3674, kl/H(z|x): 0.0016, mi: -0.0893, recon: 43.3658, nll: 43.3674, ppl: 51.5467
+epoch: 17, iter: 8500, avg_loss: 44.1956, kl/H(z|x): 0.0016, recon: 44.1940,time elapsed 367.18s
+epoch: 17, iter: 8550, avg_loss: 44.4295, kl/H(z|x): 0.0022, recon: 44.4273,time elapsed 368.29s
+epoch: 17, iter: 8600, avg_loss: 45.0237, kl/H(z|x): 0.0025, recon: 45.0211,time elapsed 369.39s
+epoch: 17, iter: 8650, avg_loss: 45.2083, kl/H(z|x): 0.0016, recon: 45.2067,time elapsed 370.48s
+epoch: 17, iter: 8700, avg_loss: 45.0486, kl/H(z|x): 0.0021, recon: 45.0465,time elapsed 371.57s
+epoch: 17, iter: 8750, avg_loss: 45.1202, kl/H(z|x): 0.0022, recon: 45.1180,time elapsed 372.68s
+epoch: 17, iter: 8800, avg_loss: 45.5614, kl/H(z|x): 0.0018, recon: 45.5596,time elapsed 373.76s
+epoch: 17, iter: 8850, avg_loss: 44.9460, kl/H(z|x): 0.0018, recon: 44.9442,time elapsed 374.91s
+epoch: 17, iter: 8900, avg_loss: 45.2071, kl/H(z|x): 0.0018, recon: 45.2053,time elapsed 376.01s
+epoch: 17, iter: 8950, avg_loss: 45.1195, kl/H(z|x): 0.0020, recon: 45.1175,time elapsed 377.09s
+kl weight 1.0000
+VAL --- avg_loss: 43.5796, kl/H(z|x): 0.0035, mi: 0.1645, recon: 43.5761, nll: 43.5796, ppl: 52.5508
+0 active units
+TEST --- avg_loss: 43.5770, kl/H(z|x): 0.0035, mi: -0.0107, recon: 43.5734, nll: 43.5770, ppl: 52.5384
+epoch: 18, iter: 9000, avg_loss: 44.2011, kl/H(z|x): 0.0035, recon: 44.1976,time elapsed 386.56s
+epoch: 18, iter: 9050, avg_loss: 45.1314, kl/H(z|x): 0.0018, recon: 45.1296,time elapsed 387.79s
+epoch: 18, iter: 9100, avg_loss: 44.8351, kl/H(z|x): 0.0019, recon: 44.8331,time elapsed 388.95s
+epoch: 18, iter: 9150, avg_loss: 44.7113, kl/H(z|x): 0.0019, recon: 44.7095,time elapsed 390.06s
+epoch: 18, iter: 9200, avg_loss: 44.9135, kl/H(z|x): 0.0016, recon: 44.9119,time elapsed 391.16s
+epoch: 18, iter: 9250, avg_loss: 44.8027, kl/H(z|x): 0.0023, recon: 44.8005,time elapsed 392.26s
+epoch: 18, iter: 9300, avg_loss: 45.2509, kl/H(z|x): 0.0025, recon: 45.2485,time elapsed 393.38s
+epoch: 18, iter: 9350, avg_loss: 45.4845, kl/H(z|x): 0.0023, recon: 45.4823,time elapsed 394.47s
+epoch: 18, iter: 9400, avg_loss: 45.1169, kl/H(z|x): 0.0036, recon: 45.1133,time elapsed 395.54s
+epoch: 18, iter: 9450, avg_loss: 45.0380, kl/H(z|x): 0.0023, recon: 45.0357,time elapsed 396.64s
+kl weight 1.0000
+VAL --- avg_loss: 43.2472, kl/H(z|x): 0.0016, mi: -0.0371, recon: 43.2456, nll: 43.2472, ppl: 50.9865
+0 active units
+update best loss
+TEST --- avg_loss: 43.2442, kl/H(z|x): 0.0016, mi: -0.0985, recon: 43.2427, nll: 43.2442, ppl: 50.9729
+epoch: 19, iter: 9500, avg_loss: 43.3491, kl/H(z|x): 0.0016, recon: 43.3475,time elapsed 406.21s
+epoch: 19, iter: 9550, avg_loss: 44.9036, kl/H(z|x): 0.0019, recon: 44.9017,time elapsed 407.30s
+epoch: 19, iter: 9600, avg_loss: 45.0755, kl/H(z|x): 0.0016, recon: 45.0739,time elapsed 408.42s
+epoch: 19, iter: 9650, avg_loss: 44.7566, kl/H(z|x): 0.0016, recon: 44.7551,time elapsed 409.54s
+epoch: 19, iter: 9700, avg_loss: 44.9752, kl/H(z|x): 0.0015, recon: 44.9737,time elapsed 410.67s
+epoch: 19, iter: 9750, avg_loss: 44.8479, kl/H(z|x): 0.0018, recon: 44.8461,time elapsed 411.93s
+epoch: 19, iter: 9800, avg_loss: 45.3882, kl/H(z|x): 0.0016, recon: 45.3866,time elapsed 413.19s
+epoch: 19, iter: 9850, avg_loss: 44.9943, kl/H(z|x): 0.0016, recon: 44.9927,time elapsed 414.33s
+epoch: 19, iter: 9900, avg_loss: 44.6496, kl/H(z|x): 0.0025, recon: 44.6471,time elapsed 415.42s
+epoch: 19, iter: 9950, avg_loss: 44.8607, kl/H(z|x): 0.0018, recon: 44.8588,time elapsed 416.51s
+kl weight 1.0000
+VAL --- avg_loss: 43.2890, kl/H(z|x): 0.0025, mi: -0.0208, recon: 43.2864, nll: 43.2890, ppl: 51.1807
+0 active units
+TEST --- avg_loss: 43.2645, kl/H(z|x): 0.0025, mi: -0.0091, recon: 43.2619, nll: 43.2645, ppl: 51.0668
+epoch: 20, iter: 10000, avg_loss: 46.8134, kl/H(z|x): 0.0025, recon: 46.8108,time elapsed 426.14s
+epoch: 20, iter: 10050, avg_loss: 44.5414, kl/H(z|x): 0.0015, recon: 44.5399,time elapsed 427.27s
+epoch: 20, iter: 10100, avg_loss: 44.8945, kl/H(z|x): 0.0022, recon: 44.8923,time elapsed 428.39s
+epoch: 20, iter: 10150, avg_loss: 44.5182, kl/H(z|x): 0.0016, recon: 44.5166,time elapsed 429.50s
+epoch: 20, iter: 10200, avg_loss: 44.9697, kl/H(z|x): 0.0016, recon: 44.9681,time elapsed 430.60s
+epoch: 20, iter: 10250, avg_loss: 45.2996, kl/H(z|x): 0.0024, recon: 45.2973,time elapsed 431.72s
+epoch: 20, iter: 10300, avg_loss: 45.0469, kl/H(z|x): 0.0023, recon: 45.0446,time elapsed 432.89s
+epoch: 20, iter: 10350, avg_loss: 44.9282, kl/H(z|x): 0.0021, recon: 44.9261,time elapsed 434.01s
+epoch: 20, iter: 10400, avg_loss: 45.3862, kl/H(z|x): 0.0018, recon: 45.3844,time elapsed 435.10s
+epoch: 20, iter: 10450, avg_loss: 44.9310, kl/H(z|x): 0.0017, recon: 44.9292,time elapsed 436.25s
+kl weight 1.0000
+VAL --- avg_loss: 43.3190, kl/H(z|x): 0.0026, mi: -0.0782, recon: 43.3164, nll: 43.3190, ppl: 51.3205
+0 active units
+TEST --- avg_loss: 43.3047, kl/H(z|x): 0.0026, mi: 0.0224, recon: 43.3021, nll: 43.3047, ppl: 51.2536
+epoch: 21, iter: 10500, avg_loss: 43.5460, kl/H(z|x): 0.0026, recon: 43.5434,time elapsed 445.80s
+epoch: 21, iter: 10550, avg_loss: 44.5981, kl/H(z|x): 0.0017, recon: 44.5964,time elapsed 446.89s
+epoch: 21, iter: 10600, avg_loss: 45.0290, kl/H(z|x): 0.0018, recon: 45.0272,time elapsed 448.20s
+epoch: 21, iter: 10650, avg_loss: 44.9590, kl/H(z|x): 0.0015, recon: 44.9575,time elapsed 449.30s
+epoch: 21, iter: 10700, avg_loss: 44.7474, kl/H(z|x): 0.0013, recon: 44.7460,time elapsed 450.43s
+epoch: 21, iter: 10750, avg_loss: 44.8494, kl/H(z|x): 0.0014, recon: 44.8480,time elapsed 451.53s
+epoch: 21, iter: 10800, avg_loss: 44.7393, kl/H(z|x): 0.0016, recon: 44.7377,time elapsed 452.66s
+epoch: 21, iter: 10850, avg_loss: 44.8197, kl/H(z|x): 0.0019, recon: 44.8178,time elapsed 453.75s
+epoch: 21, iter: 10900, avg_loss: 45.2948, kl/H(z|x): 0.0025, recon: 45.2923,time elapsed 454.83s
+epoch: 21, iter: 10950, avg_loss: 45.0870, kl/H(z|x): 0.0012, recon: 45.0859,time elapsed 455.92s
+kl weight 1.0000
+VAL --- avg_loss: 43.4467, kl/H(z|x): 0.0016, mi: -0.0297, recon: 43.4450, nll: 43.4467, ppl: 51.9196
+0 active units
+TEST --- avg_loss: 43.4356, kl/H(z|x): 0.0016, mi: 0.0437, recon: 43.4339, nll: 43.4356, ppl: 51.8673
+epoch: 22, iter: 11000, avg_loss: 46.4844, kl/H(z|x): 0.0016, recon: 46.4828,time elapsed 465.42s
+epoch: 22, iter: 11050, avg_loss: 44.5962, kl/H(z|x): 0.0016, recon: 44.5946,time elapsed 466.52s
+epoch: 22, iter: 11100, avg_loss: 44.8375, kl/H(z|x): 0.0019, recon: 44.8355,time elapsed 467.66s
+epoch: 22, iter: 11150, avg_loss: 44.6757, kl/H(z|x): 0.0015, recon: 44.6742,time elapsed 468.76s
+epoch: 22, iter: 11200, avg_loss: 45.0933, kl/H(z|x): 0.0015, recon: 45.0918,time elapsed 469.87s
+epoch: 22, iter: 11250, avg_loss: 45.0801, kl/H(z|x): 0.0017, recon: 45.0784,time elapsed 470.97s
+epoch: 22, iter: 11300, avg_loss: 44.6852, kl/H(z|x): 0.0017, recon: 44.6835,time elapsed 472.09s
+epoch: 22, iter: 11350, avg_loss: 45.4787, kl/H(z|x): 0.0016, recon: 45.4771,time elapsed 473.22s
+epoch: 22, iter: 11400, avg_loss: 44.6543, kl/H(z|x): 0.0016, recon: 44.6527,time elapsed 474.32s
+epoch: 22, iter: 11450, avg_loss: 44.8205, kl/H(z|x): 0.0013, recon: 44.8191,time elapsed 475.42s
+kl weight 1.0000
+VAL --- avg_loss: 43.5634, kl/H(z|x): 0.0009, mi: -0.0111, recon: 43.5625, nll: 43.5634, ppl: 52.4734
+0 active units
+TEST --- avg_loss: 43.5467, kl/H(z|x): 0.0009, mi: -0.0235, recon: 43.5458, nll: 43.5467, ppl: 52.3937
+epoch: 23, iter: 11500, avg_loss: 48.4146, kl/H(z|x): 0.0009, recon: 48.4137,time elapsed 484.94s
+epoch: 23, iter: 11550, avg_loss: 45.0193, kl/H(z|x): 0.0015, recon: 45.0178,time elapsed 486.04s
+epoch: 23, iter: 11600, avg_loss: 44.6590, kl/H(z|x): 0.0015, recon: 44.6575,time elapsed 487.14s
+epoch: 23, iter: 11650, avg_loss: 44.8284, kl/H(z|x): 0.0015, recon: 44.8269,time elapsed 488.26s
+epoch: 23, iter: 11700, avg_loss: 44.7549, kl/H(z|x): 0.0013, recon: 44.7536,time elapsed 489.36s
+epoch: 23, iter: 11750, avg_loss: 44.8092, kl/H(z|x): 0.0013, recon: 44.8079,time elapsed 490.46s
+epoch: 23, iter: 11800, avg_loss: 44.7132, kl/H(z|x): 0.0016, recon: 44.7116,time elapsed 491.59s
+epoch: 23, iter: 11850, avg_loss: 45.3561, kl/H(z|x): 0.0012, recon: 45.3548,time elapsed 492.71s
+epoch: 23, iter: 11900, avg_loss: 44.7064, kl/H(z|x): 0.0013, recon: 44.7051,time elapsed 493.79s
+epoch: 23, iter: 11950, avg_loss: 44.9692, kl/H(z|x): 0.0013, recon: 44.9679,time elapsed 494.90s
+kl weight 1.0000
+VAL --- avg_loss: 43.3256, kl/H(z|x): 0.0012, mi: -0.0326, recon: 43.3244, nll: 43.3256, ppl: 51.3512
+0 active units
+new lr: 0.500000
+TEST --- avg_loss: 43.2480, kl/H(z|x): 0.0016, mi: -0.0007, recon: 43.2465, nll: 43.2480, ppl: 50.9905
+epoch: 24, iter: 12000, avg_loss: 46.3091, kl/H(z|x): 0.0016, recon: 46.3076,time elapsed 504.33s
+epoch: 24, iter: 12050, avg_loss: 44.2064, kl/H(z|x): 0.0017, recon: 44.2047,time elapsed 505.42s
+epoch: 24, iter: 12100, avg_loss: 44.2650, kl/H(z|x): 0.0013, recon: 44.2637,time elapsed 506.51s
+epoch: 24, iter: 12150, avg_loss: 44.8498, kl/H(z|x): 0.0013, recon: 44.8485,time elapsed 507.73s
+epoch: 24, iter: 12200, avg_loss: 44.0413, kl/H(z|x): 0.0010, recon: 44.0403,time elapsed 508.88s
+epoch: 24, iter: 12250, avg_loss: 44.0497, kl/H(z|x): 0.0009, recon: 44.0488,time elapsed 509.98s
+epoch: 24, iter: 12300, avg_loss: 43.7201, kl/H(z|x): 0.0009, recon: 43.7193,time elapsed 511.08s
+epoch: 24, iter: 12350, avg_loss: 44.5270, kl/H(z|x): 0.0012, recon: 44.5259,time elapsed 512.16s
+epoch: 24, iter: 12400, avg_loss: 44.4890, kl/H(z|x): 0.0010, recon: 44.4880,time elapsed 513.28s
+epoch: 24, iter: 12450, avg_loss: 44.3600, kl/H(z|x): 0.0009, recon: 44.3591,time elapsed 514.36s
+kl weight 1.0000
+VAL --- avg_loss: 42.8909, kl/H(z|x): 0.0007, mi: -0.0319, recon: 42.8902, nll: 42.8909, ppl: 49.3614
+0 active units
+update best loss
+TEST --- avg_loss: 42.8864, kl/H(z|x): 0.0007, mi: 0.0619, recon: 42.8857, nll: 42.8864, ppl: 49.3415
+epoch: 25, iter: 12500, avg_loss: 43.0568, kl/H(z|x): 0.0007, recon: 43.0561,time elapsed 523.70s
+epoch: 25, iter: 12550, avg_loss: 43.7582, kl/H(z|x): 0.0008, recon: 43.7573,time elapsed 524.80s
+epoch: 25, iter: 12600, avg_loss: 43.9676, kl/H(z|x): 0.0005, recon: 43.9671,time elapsed 525.90s
+epoch: 25, iter: 12650, avg_loss: 44.1848, kl/H(z|x): 0.0005, recon: 44.1843,time elapsed 527.00s
+epoch: 25, iter: 12700, avg_loss: 44.3981, kl/H(z|x): 0.0005, recon: 44.3976,time elapsed 528.14s
+epoch: 25, iter: 12750, avg_loss: 44.3295, kl/H(z|x): 0.0005, recon: 44.3290,time elapsed 529.22s
+epoch: 25, iter: 12800, avg_loss: 44.1709, kl/H(z|x): 0.0004, recon: 44.1705,time elapsed 530.32s
+epoch: 25, iter: 12850, avg_loss: 43.7556, kl/H(z|x): 0.0005, recon: 43.7550,time elapsed 531.78s
+epoch: 25, iter: 12900, avg_loss: 44.2555, kl/H(z|x): 0.0004, recon: 44.2550,time elapsed 533.00s
+epoch: 25, iter: 12950, avg_loss: 44.1798, kl/H(z|x): 0.0004, recon: 44.1793,time elapsed 534.12s
+kl weight 1.0000
+VAL --- avg_loss: 42.9512, kl/H(z|x): 0.0005, mi: -0.0766, recon: 42.9507, nll: 42.9512, ppl: 49.6328
+0 active units
+TEST --- avg_loss: 42.9562, kl/H(z|x): 0.0005, mi: -0.0393, recon: 42.9557, nll: 42.9562, ppl: 49.6553
+epoch: 26, iter: 13000, avg_loss: 44.2116, kl/H(z|x): 0.0005, recon: 44.2112,time elapsed 543.43s
+epoch: 26, iter: 13050, avg_loss: 43.9048, kl/H(z|x): 0.0004, recon: 43.9044,time elapsed 544.51s
+epoch: 26, iter: 13100, avg_loss: 43.7069, kl/H(z|x): 0.0004, recon: 43.7065,time elapsed 545.61s
+epoch: 26, iter: 13150, avg_loss: 43.7116, kl/H(z|x): 0.0004, recon: 43.7113,time elapsed 546.69s
+epoch: 26, iter: 13200, avg_loss: 43.3178, kl/H(z|x): 0.0004, recon: 43.3174,time elapsed 547.82s
+epoch: 26, iter: 13250, avg_loss: 44.7637, kl/H(z|x): 0.0005, recon: 44.7632,time elapsed 548.92s
+epoch: 26, iter: 13300, avg_loss: 44.1573, kl/H(z|x): 0.0004, recon: 44.1569,time elapsed 550.03s
+epoch: 26, iter: 13350, avg_loss: 44.3795, kl/H(z|x): 0.0005, recon: 44.3790,time elapsed 551.11s
+epoch: 26, iter: 13400, avg_loss: 44.2767, kl/H(z|x): 0.0005, recon: 44.2761,time elapsed 552.20s
+epoch: 26, iter: 13450, avg_loss: 43.9566, kl/H(z|x): 0.0004, recon: 43.9561,time elapsed 553.31s
+kl weight 1.0000
+VAL --- avg_loss: 42.8416, kl/H(z|x): 0.0003, mi: 0.0639, recon: 42.8413, nll: 42.8416, ppl: 49.1409
+0 active units
+update best loss
+TEST --- avg_loss: 42.8492, kl/H(z|x): 0.0003, mi: 0.0448, recon: 42.8488, nll: 42.8492, ppl: 49.1747
+epoch: 27, iter: 13500, avg_loss: 44.2243, kl/H(z|x): 0.0003, recon: 44.2240,time elapsed 562.83s
+epoch: 27, iter: 13550, avg_loss: 43.7787, kl/H(z|x): 0.0005, recon: 43.7782,time elapsed 563.92s
+epoch: 27, iter: 13600, avg_loss: 43.8797, kl/H(z|x): 0.0004, recon: 43.8793,time elapsed 565.00s
+epoch: 27, iter: 13650, avg_loss: 44.2306, kl/H(z|x): 0.0006, recon: 44.2301,time elapsed 566.18s
+epoch: 27, iter: 13700, avg_loss: 43.7812, kl/H(z|x): 0.0005, recon: 43.7807,time elapsed 567.29s
+epoch: 27, iter: 13750, avg_loss: 43.6520, kl/H(z|x): 0.0003, recon: 43.6517,time elapsed 568.59s
+epoch: 27, iter: 13800, avg_loss: 44.3062, kl/H(z|x): 0.0004, recon: 44.3058,time elapsed 569.68s
+epoch: 27, iter: 13850, avg_loss: 43.7553, kl/H(z|x): 0.0004, recon: 43.7549,time elapsed 570.80s
+epoch: 27, iter: 13900, avg_loss: 44.5534, kl/H(z|x): 0.0004, recon: 44.5529,time elapsed 571.90s
+epoch: 27, iter: 13950, avg_loss: 44.0106, kl/H(z|x): 0.0004, recon: 44.0102,time elapsed 573.03s
+kl weight 1.0000
+VAL --- avg_loss: 42.8331, kl/H(z|x): 0.0004, mi: 0.1504, recon: 42.8327, nll: 42.8331, ppl: 49.1029
+0 active units
+update best loss
+TEST --- avg_loss: 42.8320, kl/H(z|x): 0.0004, mi: 0.0753, recon: 42.8316, nll: 42.8320, ppl: 49.0980
+epoch: 28, iter: 14000, avg_loss: 39.6822, kl/H(z|x): 0.0005, recon: 39.6818,time elapsed 582.32s
+epoch: 28, iter: 14050, avg_loss: 43.7772, kl/H(z|x): 0.0004, recon: 43.7768,time elapsed 583.51s
+epoch: 28, iter: 14100, avg_loss: 44.1639, kl/H(z|x): 0.0004, recon: 44.1635,time elapsed 584.65s
+epoch: 28, iter: 14150, avg_loss: 43.8640, kl/H(z|x): 0.0004, recon: 43.8636,time elapsed 585.79s
+epoch: 28, iter: 14200, avg_loss: 43.9483, kl/H(z|x): 0.0005, recon: 43.9479,time elapsed 586.92s
+epoch: 28, iter: 14250, avg_loss: 43.9844, kl/H(z|x): 0.0006, recon: 43.9838,time elapsed 588.11s
+epoch: 28, iter: 14300, avg_loss: 44.2249, kl/H(z|x): 0.0005, recon: 44.2243,time elapsed 589.23s
+epoch: 28, iter: 14350, avg_loss: 43.7008, kl/H(z|x): 0.0004, recon: 43.7004,time elapsed 590.35s
+epoch: 28, iter: 14400, avg_loss: 43.8595, kl/H(z|x): 0.0004, recon: 43.8591,time elapsed 591.52s
+epoch: 28, iter: 14450, avg_loss: 44.3218, kl/H(z|x): 0.0004, recon: 44.3214,time elapsed 592.68s
+kl weight 1.0000
+VAL --- avg_loss: 42.8649, kl/H(z|x): 0.0004, mi: -0.0395, recon: 42.8646, nll: 42.8649, ppl: 49.2452
+0 active units
+TEST --- avg_loss: 42.8816, kl/H(z|x): 0.0004, mi: 0.1842, recon: 42.8813, nll: 42.8816, ppl: 49.3200
+epoch: 29, iter: 14500, avg_loss: 43.6188, kl/H(z|x): 0.0003, recon: 43.6184,time elapsed 602.00s
+epoch: 29, iter: 14550, avg_loss: 43.7076, kl/H(z|x): 0.0004, recon: 43.7072,time elapsed 603.14s
+epoch: 29, iter: 14600, avg_loss: 43.7143, kl/H(z|x): 0.0004, recon: 43.7139,time elapsed 604.23s
+epoch: 29, iter: 14650, avg_loss: 44.3535, kl/H(z|x): 0.0004, recon: 44.3531,time elapsed 605.30s
+epoch: 29, iter: 14700, avg_loss: 43.7975, kl/H(z|x): 0.0004, recon: 43.7971,time elapsed 606.38s
+epoch: 29, iter: 14750, avg_loss: 44.0038, kl/H(z|x): 0.0004, recon: 44.0034,time elapsed 607.48s
+epoch: 29, iter: 14800, avg_loss: 44.1614, kl/H(z|x): 0.0004, recon: 44.1611,time elapsed 608.57s
+epoch: 29, iter: 14850, avg_loss: 43.5950, kl/H(z|x): 0.0003, recon: 43.5947,time elapsed 609.66s
+epoch: 29, iter: 14900, avg_loss: 44.0414, kl/H(z|x): 0.0003, recon: 44.0411,time elapsed 610.74s
+epoch: 29, iter: 14950, avg_loss: 43.7423, kl/H(z|x): 0.0005, recon: 43.7418,time elapsed 611.83s
+kl weight 1.0000
+VAL --- avg_loss: 42.7344, kl/H(z|x): 0.0005, mi: 0.1063, recon: 42.7340, nll: 42.7344, ppl: 48.6645
+0 active units
+update best loss
+TEST --- avg_loss: 42.7331, kl/H(z|x): 0.0005, mi: -0.0573, recon: 42.7326, nll: 42.7331, ppl: 48.6584
+epoch: 30, iter: 15000, avg_loss: 43.9892, kl/H(z|x): 0.0005, recon: 43.9888,time elapsed 621.32s
+epoch: 30, iter: 15050, avg_loss: 43.6377, kl/H(z|x): 0.0005, recon: 43.6372,time elapsed 622.45s
+epoch: 30, iter: 15100, avg_loss: 44.1237, kl/H(z|x): 0.0004, recon: 44.1233,time elapsed 623.51s
+epoch: 30, iter: 15150, avg_loss: 43.8119, kl/H(z|x): 0.0005, recon: 43.8115,time elapsed 624.59s
+epoch: 30, iter: 15200, avg_loss: 43.8647, kl/H(z|x): 0.0004, recon: 43.8643,time elapsed 625.70s
+epoch: 30, iter: 15250, avg_loss: 43.9263, kl/H(z|x): 0.0003, recon: 43.9259,time elapsed 626.80s
+epoch: 30, iter: 15300, avg_loss: 44.1386, kl/H(z|x): 0.0004, recon: 44.1382,time elapsed 628.34s
+epoch: 30, iter: 15350, avg_loss: 44.2637, kl/H(z|x): 0.0004, recon: 44.2633,time elapsed 629.57s
+epoch: 30, iter: 15400, avg_loss: 43.6037, kl/H(z|x): 0.0005, recon: 43.6032,time elapsed 630.79s
+epoch: 30, iter: 15450, avg_loss: 43.9861, kl/H(z|x): 0.0004, recon: 43.9857,time elapsed 631.99s
+kl weight 1.0000
+VAL --- avg_loss: 42.8454, kl/H(z|x): 0.0005, mi: -0.0501, recon: 42.8449, nll: 42.8454, ppl: 49.1579
+0 active units
+TEST --- avg_loss: 42.8717, kl/H(z|x): 0.0005, mi: 0.0334, recon: 42.8712, nll: 42.8717, ppl: 49.2755
+epoch: 31, iter: 15500, avg_loss: 49.3075, kl/H(z|x): 0.0005, recon: 49.3070,time elapsed 641.71s
+epoch: 31, iter: 15550, avg_loss: 44.0898, kl/H(z|x): 0.0004, recon: 44.0893,time elapsed 642.82s
+epoch: 31, iter: 15600, avg_loss: 43.8914, kl/H(z|x): 0.0004, recon: 43.8910,time elapsed 643.90s
+epoch: 31, iter: 15650, avg_loss: 43.8410, kl/H(z|x): 0.0004, recon: 43.8405,time elapsed 645.00s
+epoch: 31, iter: 15700, avg_loss: 43.4980, kl/H(z|x): 0.0004, recon: 43.4976,time elapsed 646.09s
+epoch: 31, iter: 15750, avg_loss: 44.1395, kl/H(z|x): 0.0004, recon: 44.1391,time elapsed 647.17s
+epoch: 31, iter: 15800, avg_loss: 43.8627, kl/H(z|x): 0.0004, recon: 43.8623,time elapsed 648.31s
+epoch: 31, iter: 15850, avg_loss: 44.0324, kl/H(z|x): 0.0004, recon: 44.0320,time elapsed 649.39s
+epoch: 31, iter: 15900, avg_loss: 43.8421, kl/H(z|x): 0.0005, recon: 43.8416,time elapsed 650.48s
+epoch: 31, iter: 15950, avg_loss: 44.5026, kl/H(z|x): 0.0004, recon: 44.5022,time elapsed 651.68s
+kl weight 1.0000
+VAL --- avg_loss: 42.8186, kl/H(z|x): 0.0004, mi: -0.0377, recon: 42.8182, nll: 42.8186, ppl: 49.0381
+0 active units
+TEST --- avg_loss: 42.8374, kl/H(z|x): 0.0004, mi: 0.0111, recon: 42.8370, nll: 42.8374, ppl: 49.1222
+epoch: 32, iter: 16000, avg_loss: 43.7023, kl/H(z|x): 0.0004, recon: 43.7019,time elapsed 661.15s
+epoch: 32, iter: 16050, avg_loss: 43.7157, kl/H(z|x): 0.0003, recon: 43.7153,time elapsed 662.24s
+epoch: 32, iter: 16100, avg_loss: 43.2802, kl/H(z|x): 0.0004, recon: 43.2798,time elapsed 663.37s
+epoch: 32, iter: 16150, avg_loss: 44.0081, kl/H(z|x): 0.0004, recon: 44.0077,time elapsed 664.48s
+epoch: 32, iter: 16200, avg_loss: 44.0434, kl/H(z|x): 0.0004, recon: 44.0430,time elapsed 665.55s
+epoch: 32, iter: 16250, avg_loss: 44.1416, kl/H(z|x): 0.0004, recon: 44.1411,time elapsed 666.63s
+epoch: 32, iter: 16300, avg_loss: 43.9179, kl/H(z|x): 0.0005, recon: 43.9174,time elapsed 667.74s
+epoch: 32, iter: 16350, avg_loss: 43.9227, kl/H(z|x): 0.0005, recon: 43.9222,time elapsed 668.83s
+epoch: 32, iter: 16400, avg_loss: 43.8702, kl/H(z|x): 0.0005, recon: 43.8697,time elapsed 669.93s
+epoch: 32, iter: 16450, avg_loss: 43.8965, kl/H(z|x): 0.0006, recon: 43.8959,time elapsed 671.01s
+kl weight 1.0000
+VAL --- avg_loss: 42.8250, kl/H(z|x): 0.0004, mi: 0.0814, recon: 42.8246, nll: 42.8250, ppl: 49.0668
+0 active units
+TEST --- avg_loss: 42.8447, kl/H(z|x): 0.0004, mi: 0.0192, recon: 42.8443, nll: 42.8447, ppl: 49.1547
+epoch: 33, iter: 16500, avg_loss: 44.2517, kl/H(z|x): 0.0004, recon: 44.2513,time elapsed 680.38s
+epoch: 33, iter: 16550, avg_loss: 43.6145, kl/H(z|x): 0.0004, recon: 43.6141,time elapsed 681.48s
+epoch: 33, iter: 16600, avg_loss: 44.2139, kl/H(z|x): 0.0004, recon: 44.2134,time elapsed 682.60s
+epoch: 33, iter: 16650, avg_loss: 43.5590, kl/H(z|x): 0.0005, recon: 43.5585,time elapsed 683.69s
+epoch: 33, iter: 16700, avg_loss: 43.5167, kl/H(z|x): 0.0004, recon: 43.5164,time elapsed 684.78s
+epoch: 33, iter: 16750, avg_loss: 44.3663, kl/H(z|x): 0.0004, recon: 44.3659,time elapsed 685.95s
+epoch: 33, iter: 16800, avg_loss: 43.3762, kl/H(z|x): 0.0005, recon: 43.3757,time elapsed 687.36s
+epoch: 33, iter: 16850, avg_loss: 43.6989, kl/H(z|x): 0.0005, recon: 43.6984,time elapsed 688.73s
+epoch: 33, iter: 16900, avg_loss: 44.0670, kl/H(z|x): 0.0004, recon: 44.0666,time elapsed 689.83s
+epoch: 33, iter: 16950, avg_loss: 44.3268, kl/H(z|x): 0.0004, recon: 44.3263,time elapsed 690.96s
+kl weight 1.0000
+VAL --- avg_loss: 42.7766, kl/H(z|x): 0.0005, mi: 0.0011, recon: 42.7761, nll: 42.7766, ppl: 48.8513
+0 active units
+TEST --- avg_loss: 42.7852, kl/H(z|x): 0.0005, mi: 0.0312, recon: 42.7847, nll: 42.7852, ppl: 48.8897
+epoch: 34, iter: 17000, avg_loss: 45.9578, kl/H(z|x): 0.0005, recon: 45.9573,time elapsed 700.26s
+epoch: 34, iter: 17050, avg_loss: 43.4863, kl/H(z|x): 0.0005, recon: 43.4858,time elapsed 701.34s
+epoch: 34, iter: 17100, avg_loss: 43.1265, kl/H(z|x): 0.0004, recon: 43.1261,time elapsed 702.45s
+epoch: 34, iter: 17150, avg_loss: 44.2183, kl/H(z|x): 0.0004, recon: 44.2179,time elapsed 703.54s
+epoch: 34, iter: 17200, avg_loss: 44.0181, kl/H(z|x): 0.0004, recon: 44.0177,time elapsed 704.61s
+epoch: 34, iter: 17250, avg_loss: 43.8095, kl/H(z|x): 0.0004, recon: 43.8091,time elapsed 705.71s
+epoch: 34, iter: 17300, avg_loss: 43.8077, kl/H(z|x): 0.0005, recon: 43.8073,time elapsed 706.89s
+epoch: 34, iter: 17350, avg_loss: 43.9956, kl/H(z|x): 0.0005, recon: 43.9951,time elapsed 708.04s
+epoch: 34, iter: 17400, avg_loss: 44.3191, kl/H(z|x): 0.0004, recon: 44.3187,time elapsed 709.14s
+epoch: 34, iter: 17450, avg_loss: 43.7856, kl/H(z|x): 0.0004, recon: 43.7852,time elapsed 710.23s
+kl weight 1.0000
+VAL --- avg_loss: 42.8082, kl/H(z|x): 0.0005, mi: -0.0863, recon: 42.8077, nll: 42.8082, ppl: 48.9918
+0 active units
+new lr: 0.250000
+TEST --- avg_loss: 42.7279, kl/H(z|x): 0.0005, mi: -0.0938, recon: 42.7275, nll: 42.7279, ppl: 48.6357
+epoch: 35, iter: 17500, avg_loss: 41.7098, kl/H(z|x): 0.0005, recon: 41.7093,time elapsed 719.87s
+epoch: 35, iter: 17550, avg_loss: 43.6619, kl/H(z|x): 0.0004, recon: 43.6614,time elapsed 720.96s
+epoch: 35, iter: 17600, avg_loss: 43.7670, kl/H(z|x): 0.0004, recon: 43.7666,time elapsed 722.03s
+epoch: 35, iter: 17650, avg_loss: 43.6828, kl/H(z|x): 0.0003, recon: 43.6825,time elapsed 723.16s
+epoch: 35, iter: 17700, avg_loss: 43.8141, kl/H(z|x): 0.0003, recon: 43.8138,time elapsed 724.24s
+epoch: 35, iter: 17750, avg_loss: 43.3714, kl/H(z|x): 0.0003, recon: 43.3711,time elapsed 725.33s
+epoch: 35, iter: 17800, avg_loss: 43.9177, kl/H(z|x): 0.0003, recon: 43.9174,time elapsed 726.41s
+epoch: 35, iter: 17850, avg_loss: 43.1331, kl/H(z|x): 0.0003, recon: 43.1329,time elapsed 727.53s
+epoch: 35, iter: 17900, avg_loss: 43.3851, kl/H(z|x): 0.0002, recon: 43.3849,time elapsed 728.62s
+epoch: 35, iter: 17950, avg_loss: 43.5729, kl/H(z|x): 0.0002, recon: 43.5727,time elapsed 729.68s
+kl weight 1.0000
+VAL --- avg_loss: 42.6877, kl/H(z|x): 0.0002, mi: -0.1456, recon: 42.6875, nll: 42.6877, ppl: 48.4583
+0 active units
+update best loss
+TEST --- avg_loss: 42.6792, kl/H(z|x): 0.0002, mi: -0.0167, recon: 42.6791, nll: 42.6792, ppl: 48.4209
+epoch: 36, iter: 18000, avg_loss: 43.9516, kl/H(z|x): 0.0002, recon: 43.9514,time elapsed 739.03s
+epoch: 36, iter: 18050, avg_loss: 43.7369, kl/H(z|x): 0.0002, recon: 43.7367,time elapsed 740.13s
+epoch: 36, iter: 18100, avg_loss: 43.5138, kl/H(z|x): 0.0003, recon: 43.5136,time elapsed 741.29s
+epoch: 36, iter: 18150, avg_loss: 43.5765, kl/H(z|x): 0.0002, recon: 43.5763,time elapsed 742.49s
+epoch: 36, iter: 18200, avg_loss: 43.2139, kl/H(z|x): 0.0002, recon: 43.2137,time elapsed 743.57s
+epoch: 36, iter: 18250, avg_loss: 43.4710, kl/H(z|x): 0.0002, recon: 43.4709,time elapsed 744.67s
+epoch: 36, iter: 18300, avg_loss: 43.4375, kl/H(z|x): 0.0002, recon: 43.4373,time elapsed 745.75s
+epoch: 36, iter: 18350, avg_loss: 43.5230, kl/H(z|x): 0.0001, recon: 43.5228,time elapsed 746.84s
+epoch: 36, iter: 18400, avg_loss: 43.9662, kl/H(z|x): 0.0001, recon: 43.9661,time elapsed 748.13s
+epoch: 36, iter: 18450, avg_loss: 43.7149, kl/H(z|x): 0.0001, recon: 43.7147,time elapsed 749.24s
+kl weight 1.0000
+VAL --- avg_loss: 42.6654, kl/H(z|x): 0.0002, mi: -0.0094, recon: 42.6652, nll: 42.6654, ppl: 48.3601
+0 active units
+update best loss
+TEST --- avg_loss: 42.6540, kl/H(z|x): 0.0002, mi: -0.0618, recon: 42.6538, nll: 42.6540, ppl: 48.3100
+epoch: 37, iter: 18500, avg_loss: 40.8473, kl/H(z|x): 0.0002, recon: 40.8471,time elapsed 758.54s
+epoch: 37, iter: 18550, avg_loss: 43.1240, kl/H(z|x): 0.0002, recon: 43.1239,time elapsed 759.63s
+epoch: 37, iter: 18600, avg_loss: 43.6882, kl/H(z|x): 0.0002, recon: 43.6880,time elapsed 760.72s
+epoch: 37, iter: 18650, avg_loss: 43.5877, kl/H(z|x): 0.0001, recon: 43.5876,time elapsed 761.81s
+epoch: 37, iter: 18700, avg_loss: 44.0302, kl/H(z|x): 0.0001, recon: 44.0301,time elapsed 762.93s
+epoch: 37, iter: 18750, avg_loss: 43.6260, kl/H(z|x): 0.0001, recon: 43.6258,time elapsed 764.03s
+epoch: 37, iter: 18800, avg_loss: 43.3254, kl/H(z|x): 0.0001, recon: 43.3252,time elapsed 765.13s
+epoch: 37, iter: 18850, avg_loss: 43.7614, kl/H(z|x): 0.0002, recon: 43.7612,time elapsed 766.22s
+epoch: 37, iter: 18900, avg_loss: 43.4881, kl/H(z|x): 0.0002, recon: 43.4879,time elapsed 767.32s
+epoch: 37, iter: 18950, avg_loss: 43.3303, kl/H(z|x): 0.0002, recon: 43.3301,time elapsed 768.47s
+kl weight 1.0000
+VAL --- avg_loss: 42.6456, kl/H(z|x): 0.0002, mi: -0.0024, recon: 42.6454, nll: 42.6456, ppl: 48.2729
+0 active units
+update best loss
+TEST --- avg_loss: 42.6398, kl/H(z|x): 0.0002, mi: -0.0623, recon: 42.6396, nll: 42.6398, ppl: 48.2474
+epoch: 38, iter: 19000, avg_loss: 39.2798, kl/H(z|x): 0.0002, recon: 39.2796,time elapsed 777.94s
+epoch: 38, iter: 19050, avg_loss: 43.5151, kl/H(z|x): 0.0002, recon: 43.5150,time elapsed 779.02s
+epoch: 38, iter: 19100, avg_loss: 43.5278, kl/H(z|x): 0.0001, recon: 43.5277,time elapsed 780.12s
+epoch: 38, iter: 19150, avg_loss: 43.0568, kl/H(z|x): 0.0001, recon: 43.0567,time elapsed 781.20s
+epoch: 38, iter: 19200, avg_loss: 43.2855, kl/H(z|x): 0.0001, recon: 43.2854,time elapsed 782.34s
+epoch: 38, iter: 19250, avg_loss: 43.9021, kl/H(z|x): 0.0001, recon: 43.9020,time elapsed 783.41s
+epoch: 38, iter: 19300, avg_loss: 43.6581, kl/H(z|x): 0.0001, recon: 43.6580,time elapsed 784.47s
+epoch: 38, iter: 19350, avg_loss: 43.7079, kl/H(z|x): 0.0001, recon: 43.7078,time elapsed 785.55s
+epoch: 38, iter: 19400, avg_loss: 43.3235, kl/H(z|x): 0.0001, recon: 43.3233,time elapsed 786.63s
+epoch: 38, iter: 19450, avg_loss: 43.6018, kl/H(z|x): 0.0001, recon: 43.6017,time elapsed 787.74s
+kl weight 1.0000
+VAL --- avg_loss: 42.6433, kl/H(z|x): 0.0001, mi: 0.0521, recon: 42.6432, nll: 42.6433, ppl: 48.2628
+0 active units
+update best loss
+TEST --- avg_loss: 42.6409, kl/H(z|x): 0.0001, mi: -0.0322, recon: 42.6408, nll: 42.6409, ppl: 48.2525
+epoch: 39, iter: 19500, avg_loss: 42.6294, kl/H(z|x): 0.0001, recon: 42.6293,time elapsed 797.07s
+epoch: 39, iter: 19550, avg_loss: 43.3645, kl/H(z|x): 0.0001, recon: 43.3644,time elapsed 798.21s
+epoch: 39, iter: 19600, avg_loss: 43.5161, kl/H(z|x): 0.0002, recon: 43.5160,time elapsed 799.30s
+epoch: 39, iter: 19650, avg_loss: 43.6186, kl/H(z|x): 0.0001, recon: 43.6185,time elapsed 800.42s
+epoch: 39, iter: 19700, avg_loss: 43.1040, kl/H(z|x): 0.0001, recon: 43.1039,time elapsed 801.51s
+epoch: 39, iter: 19750, avg_loss: 43.1631, kl/H(z|x): 0.0001, recon: 43.1630,time elapsed 802.64s
+epoch: 39, iter: 19800, avg_loss: 43.8840, kl/H(z|x): 0.0001, recon: 43.8839,time elapsed 803.72s
+epoch: 39, iter: 19850, avg_loss: 43.0590, kl/H(z|x): 0.0002, recon: 43.0588,time elapsed 804.80s
+epoch: 39, iter: 19900, avg_loss: 43.7155, kl/H(z|x): 0.0001, recon: 43.7154,time elapsed 805.88s
+epoch: 39, iter: 19950, avg_loss: 43.7701, kl/H(z|x): 0.0001, recon: 43.7700,time elapsed 807.06s
+kl weight 1.0000
+VAL --- avg_loss: 42.6654, kl/H(z|x): 0.0001, mi: -0.0592, recon: 42.6653, nll: 42.6654, ppl: 48.3601
+0 active units
+TEST --- avg_loss: 42.6620, kl/H(z|x): 0.0001, mi: 0.0123, recon: 42.6618, nll: 42.6620, ppl: 48.3449
+epoch: 40, iter: 20000, avg_loss: 40.8213, kl/H(z|x): 0.0001, recon: 40.8212,time elapsed 816.66s
+epoch: 40, iter: 20050, avg_loss: 43.6310, kl/H(z|x): 0.0001, recon: 43.6309,time elapsed 817.78s
+epoch: 40, iter: 20100, avg_loss: 43.3006, kl/H(z|x): 0.0001, recon: 43.3004,time elapsed 818.86s
+epoch: 40, iter: 20150, avg_loss: 43.5762, kl/H(z|x): 0.0001, recon: 43.5761,time elapsed 819.95s
+epoch: 40, iter: 20200, avg_loss: 43.1754, kl/H(z|x): 0.0001, recon: 43.1753,time elapsed 821.04s
+epoch: 40, iter: 20250, avg_loss: 43.1876, kl/H(z|x): 0.0001, recon: 43.1875,time elapsed 822.11s
+epoch: 40, iter: 20300, avg_loss: 43.6291, kl/H(z|x): 0.0001, recon: 43.6290,time elapsed 823.24s
+epoch: 40, iter: 20350, avg_loss: 43.5045, kl/H(z|x): 0.0001, recon: 43.5044,time elapsed 824.33s
+epoch: 40, iter: 20400, avg_loss: 43.7003, kl/H(z|x): 0.0001, recon: 43.7002,time elapsed 825.44s
+epoch: 40, iter: 20450, avg_loss: 43.0749, kl/H(z|x): 0.0001, recon: 43.0748,time elapsed 826.51s
+kl weight 1.0000
+VAL --- avg_loss: 42.6483, kl/H(z|x): 0.0001, mi: 0.1020, recon: 42.6482, nll: 42.6483, ppl: 48.2851
+0 active units
+TEST --- avg_loss: 42.6362, kl/H(z|x): 0.0001, mi: 0.0486, recon: 42.6361, nll: 42.6362, ppl: 48.2319
+epoch: 41, iter: 20500, avg_loss: 46.2897, kl/H(z|x): 0.0001, recon: 46.2896,time elapsed 836.49s
+epoch: 41, iter: 20550, avg_loss: 43.3702, kl/H(z|x): 0.0001, recon: 43.3701,time elapsed 837.65s
+epoch: 41, iter: 20600, avg_loss: 43.3133, kl/H(z|x): 0.0001, recon: 43.3132,time elapsed 838.74s
+epoch: 41, iter: 20650, avg_loss: 43.0952, kl/H(z|x): 0.0001, recon: 43.0951,time elapsed 839.82s
+epoch: 41, iter: 20700, avg_loss: 43.4714, kl/H(z|x): 0.0001, recon: 43.4712,time elapsed 840.93s
+epoch: 41, iter: 20750, avg_loss: 43.2933, kl/H(z|x): 0.0001, recon: 43.2931,time elapsed 842.23s
+epoch: 41, iter: 20800, avg_loss: 43.1513, kl/H(z|x): 0.0001, recon: 43.1511,time elapsed 843.60s
+epoch: 41, iter: 20850, avg_loss: 43.4006, kl/H(z|x): 0.0002, recon: 43.4004,time elapsed 845.02s
+epoch: 41, iter: 20900, avg_loss: 43.9970, kl/H(z|x): 0.0002, recon: 43.9968,time elapsed 846.14s
+epoch: 41, iter: 20950, avg_loss: 43.5226, kl/H(z|x): 0.0002, recon: 43.5224,time elapsed 847.26s
+kl weight 1.0000
+VAL --- avg_loss: 42.6132, kl/H(z|x): 0.0001, mi: -0.0262, recon: 42.6130, nll: 42.6132, ppl: 48.1309
+0 active units
+update best loss
+TEST --- avg_loss: 42.6104, kl/H(z|x): 0.0001, mi: 0.0578, recon: 42.6103, nll: 42.6104, ppl: 48.1187
+epoch: 42, iter: 21000, avg_loss: 45.0442, kl/H(z|x): 0.0001, recon: 45.0440,time elapsed 856.74s
+epoch: 42, iter: 21050, avg_loss: 43.2709, kl/H(z|x): 0.0001, recon: 43.2708,time elapsed 857.93s
+epoch: 42, iter: 21100, avg_loss: 43.5428, kl/H(z|x): 0.0001, recon: 43.5427,time elapsed 859.10s
+epoch: 42, iter: 21150, avg_loss: 43.0828, kl/H(z|x): 0.0001, recon: 43.0827,time elapsed 860.20s
+epoch: 42, iter: 21200, avg_loss: 43.4841, kl/H(z|x): 0.0001, recon: 43.4839,time elapsed 861.33s
+epoch: 42, iter: 21250, avg_loss: 43.1936, kl/H(z|x): 0.0001, recon: 43.1935,time elapsed 862.47s
+epoch: 42, iter: 21300, avg_loss: 43.5686, kl/H(z|x): 0.0001, recon: 43.5684,time elapsed 863.68s
+epoch: 42, iter: 21350, avg_loss: 43.8472, kl/H(z|x): 0.0001, recon: 43.8471,time elapsed 864.78s
+epoch: 42, iter: 21400, avg_loss: 43.1563, kl/H(z|x): 0.0001, recon: 43.1562,time elapsed 865.87s
+epoch: 42, iter: 21450, avg_loss: 43.3632, kl/H(z|x): 0.0001, recon: 43.3630,time elapsed 866.99s
+kl weight 1.0000
+VAL --- avg_loss: 42.6019, kl/H(z|x): 0.0001, mi: 0.0535, recon: 42.6018, nll: 42.6019, ppl: 48.0818
+0 active units
+update best loss
+TEST --- avg_loss: 42.6078, kl/H(z|x): 0.0001, mi: 0.0234, recon: 42.6077, nll: 42.6078, ppl: 48.1076
+epoch: 43, iter: 21500, avg_loss: 42.9676, kl/H(z|x): 0.0001, recon: 42.9675,time elapsed 876.50s
+epoch: 43, iter: 21550, avg_loss: 43.4092, kl/H(z|x): 0.0001, recon: 43.4091,time elapsed 877.58s
+epoch: 43, iter: 21600, avg_loss: 43.5033, kl/H(z|x): 0.0001, recon: 43.5032,time elapsed 878.69s
+epoch: 43, iter: 21650, avg_loss: 43.5289, kl/H(z|x): 0.0002, recon: 43.5287,time elapsed 879.77s
+epoch: 43, iter: 21700, avg_loss: 42.9456, kl/H(z|x): 0.0002, recon: 42.9455,time elapsed 880.87s
+epoch: 43, iter: 21750, avg_loss: 43.5842, kl/H(z|x): 0.0002, recon: 43.5841,time elapsed 881.98s
+epoch: 43, iter: 21800, avg_loss: 43.6099, kl/H(z|x): 0.0001, recon: 43.6098,time elapsed 883.08s
+epoch: 43, iter: 21850, avg_loss: 42.9298, kl/H(z|x): 0.0001, recon: 42.9297,time elapsed 884.16s
+epoch: 43, iter: 21900, avg_loss: 43.4812, kl/H(z|x): 0.0001, recon: 43.4811,time elapsed 885.24s
+epoch: 43, iter: 21950, avg_loss: 43.5893, kl/H(z|x): 0.0001, recon: 43.5892,time elapsed 886.33s
+kl weight 1.0000
+VAL --- avg_loss: 42.5868, kl/H(z|x): 0.0002, mi: -0.0469, recon: 42.5866, nll: 42.5868, ppl: 48.0157
+0 active units
+update best loss
+TEST --- avg_loss: 42.5918, kl/H(z|x): 0.0002, mi: -0.0134, recon: 42.5916, nll: 42.5918, ppl: 48.0376
+epoch: 44, iter: 22000, avg_loss: 41.1623, kl/H(z|x): 0.0002, recon: 41.1621,time elapsed 896.15s
+epoch: 44, iter: 22050, avg_loss: 43.6225, kl/H(z|x): 0.0002, recon: 43.6223,time elapsed 897.24s
+epoch: 44, iter: 22100, avg_loss: 43.0115, kl/H(z|x): 0.0001, recon: 43.0113,time elapsed 898.35s
+epoch: 44, iter: 22150, avg_loss: 43.8803, kl/H(z|x): 0.0001, recon: 43.8802,time elapsed 899.44s
+epoch: 44, iter: 22200, avg_loss: 43.7196, kl/H(z|x): 0.0001, recon: 43.7195,time elapsed 900.54s
+epoch: 44, iter: 22250, avg_loss: 43.1994, kl/H(z|x): 0.0001, recon: 43.1993,time elapsed 901.65s
+epoch: 44, iter: 22300, avg_loss: 43.2972, kl/H(z|x): 0.0001, recon: 43.2971,time elapsed 902.77s
+epoch: 44, iter: 22350, avg_loss: 43.2614, kl/H(z|x): 0.0001, recon: 43.2613,time elapsed 903.85s
+epoch: 44, iter: 22400, avg_loss: 43.0387, kl/H(z|x): 0.0001, recon: 43.0386,time elapsed 904.93s
+epoch: 44, iter: 22450, avg_loss: 43.5179, kl/H(z|x): 0.0001, recon: 43.5178,time elapsed 906.00s
+kl weight 1.0000
+VAL --- avg_loss: 42.6205, kl/H(z|x): 0.0001, mi: -0.0640, recon: 42.6204, nll: 42.6205, ppl: 48.1632
+0 active units
+TEST --- avg_loss: 42.6170, kl/H(z|x): 0.0001, mi: 0.0838, recon: 42.6168, nll: 42.6170, ppl: 48.1476
+epoch: 45, iter: 22500, avg_loss: 45.8301, kl/H(z|x): 0.0001, recon: 45.8299,time elapsed 915.37s
+epoch: 45, iter: 22550, avg_loss: 43.7427, kl/H(z|x): 0.0001, recon: 43.7426,time elapsed 916.46s
+epoch: 45, iter: 22600, avg_loss: 43.1537, kl/H(z|x): 0.0001, recon: 43.1536,time elapsed 917.58s
+epoch: 45, iter: 22650, avg_loss: 42.6766, kl/H(z|x): 0.0001, recon: 42.6765,time elapsed 918.67s
+epoch: 45, iter: 22700, avg_loss: 43.2644, kl/H(z|x): 0.0001, recon: 43.2643,time elapsed 919.77s
+epoch: 45, iter: 22750, avg_loss: 43.8260, kl/H(z|x): 0.0001, recon: 43.8259,time elapsed 920.87s
+epoch: 45, iter: 22800, avg_loss: 43.0574, kl/H(z|x): 0.0001, recon: 43.0573,time elapsed 922.02s
+epoch: 45, iter: 22850, avg_loss: 43.5961, kl/H(z|x): 0.0001, recon: 43.5960,time elapsed 923.16s
+epoch: 45, iter: 22900, avg_loss: 43.7173, kl/H(z|x): 0.0001, recon: 43.7172,time elapsed 924.25s
+epoch: 45, iter: 22950, avg_loss: 43.1578, kl/H(z|x): 0.0002, recon: 43.1577,time elapsed 925.31s
+kl weight 1.0000
+VAL --- avg_loss: 42.6148, kl/H(z|x): 0.0001, mi: -0.0306, recon: 42.6146, nll: 42.6148, ppl: 48.1380
+0 active units
+TEST --- avg_loss: 42.6229, kl/H(z|x): 0.0001, mi: 0.0225, recon: 42.6227, nll: 42.6229, ppl: 48.1734
+epoch: 46, iter: 23000, avg_loss: 45.8865, kl/H(z|x): 0.0001, recon: 45.8864,time elapsed 935.17s
+epoch: 46, iter: 23050, avg_loss: 42.8259, kl/H(z|x): 0.0002, recon: 42.8257,time elapsed 936.26s
+epoch: 46, iter: 23100, avg_loss: 43.3159, kl/H(z|x): 0.0001, recon: 43.3158,time elapsed 937.35s
+epoch: 46, iter: 23150, avg_loss: 43.1671, kl/H(z|x): 0.0001, recon: 43.1670,time elapsed 938.47s
+epoch: 46, iter: 23200, avg_loss: 43.3078, kl/H(z|x): 0.0001, recon: 43.3077,time elapsed 939.55s
+epoch: 46, iter: 23250, avg_loss: 43.6324, kl/H(z|x): 0.0001, recon: 43.6323,time elapsed 940.64s
+epoch: 46, iter: 23300, avg_loss: 43.4459, kl/H(z|x): 0.0001, recon: 43.4458,time elapsed 941.71s
+epoch: 46, iter: 23350, avg_loss: 43.8197, kl/H(z|x): 0.0002, recon: 43.8195,time elapsed 942.83s
+epoch: 46, iter: 23400, avg_loss: 43.6614, kl/H(z|x): 0.0001, recon: 43.6613,time elapsed 943.92s
+epoch: 46, iter: 23450, avg_loss: 43.3612, kl/H(z|x): 0.0001, recon: 43.3610,time elapsed 945.02s
+kl weight 1.0000
+VAL --- avg_loss: 42.5967, kl/H(z|x): 0.0001, mi: -0.0161, recon: 42.5966, nll: 42.5967, ppl: 48.0588
+0 active units
+TEST --- avg_loss: 42.6032, kl/H(z|x): 0.0001, mi: -0.0448, recon: 42.6031, nll: 42.6032, ppl: 48.0873
+epoch: 47, iter: 23500, avg_loss: 44.5862, kl/H(z|x): 0.0001, recon: 44.5860,time elapsed 954.40s
+epoch: 47, iter: 23550, avg_loss: 43.0005, kl/H(z|x): 0.0001, recon: 43.0004,time elapsed 955.51s
+epoch: 47, iter: 23600, avg_loss: 43.6972, kl/H(z|x): 0.0001, recon: 43.6971,time elapsed 956.60s
+epoch: 47, iter: 23650, avg_loss: 43.2438, kl/H(z|x): 0.0001, recon: 43.2437,time elapsed 957.72s
+epoch: 47, iter: 23700, avg_loss: 43.3149, kl/H(z|x): 0.0001, recon: 43.3147,time elapsed 958.81s
+epoch: 47, iter: 23750, avg_loss: 43.3793, kl/H(z|x): 0.0001, recon: 43.3792,time elapsed 959.92s
+epoch: 47, iter: 23800, avg_loss: 43.4107, kl/H(z|x): 0.0002, recon: 43.4105,time elapsed 960.99s
+epoch: 47, iter: 23850, avg_loss: 43.8430, kl/H(z|x): 0.0001, recon: 43.8428,time elapsed 962.09s
+epoch: 47, iter: 23900, avg_loss: 42.8022, kl/H(z|x): 0.0001, recon: 42.8021,time elapsed 963.20s
+epoch: 47, iter: 23950, avg_loss: 43.4718, kl/H(z|x): 0.0001, recon: 43.4717,time elapsed 964.31s
+kl weight 1.0000
+VAL --- avg_loss: 42.5876, kl/H(z|x): 0.0001, mi: -0.0421, recon: 42.5875, nll: 42.5876, ppl: 48.0192
+0 active units
+TEST --- avg_loss: 42.5844, kl/H(z|x): 0.0001, mi: -0.0502, recon: 42.5842, nll: 42.5844, ppl: 48.0050
+epoch: 48, iter: 24000, avg_loss: 44.4487, kl/H(z|x): 0.0001, recon: 44.4486,time elapsed 973.63s
+epoch: 48, iter: 24050, avg_loss: 42.7031, kl/H(z|x): 0.0001, recon: 42.7030,time elapsed 974.70s
+epoch: 48, iter: 24100, avg_loss: 43.6009, kl/H(z|x): 0.0001, recon: 43.6008,time elapsed 975.78s
+epoch: 48, iter: 24150, avg_loss: 43.3372, kl/H(z|x): 0.0001, recon: 43.3371,time elapsed 976.87s
+epoch: 48, iter: 24200, avg_loss: 43.4364, kl/H(z|x): 0.0001, recon: 43.4363,time elapsed 978.00s
+epoch: 48, iter: 24250, avg_loss: 43.0002, kl/H(z|x): 0.0001, recon: 43.0001,time elapsed 979.08s
+epoch: 48, iter: 24300, avg_loss: 43.3012, kl/H(z|x): 0.0001, recon: 43.3010,time elapsed 980.16s
+epoch: 48, iter: 24350, avg_loss: 43.4798, kl/H(z|x): 0.0001, recon: 43.4796,time elapsed 981.24s
+epoch: 48, iter: 24400, avg_loss: 43.6218, kl/H(z|x): 0.0001, recon: 43.6216,time elapsed 982.42s
+epoch: 48, iter: 24450, avg_loss: 43.2860, kl/H(z|x): 0.0001, recon: 43.2858,time elapsed 983.52s
+kl weight 1.0000
+VAL --- avg_loss: 42.5813, kl/H(z|x): 0.0001, mi: 0.0251, recon: 42.5811, nll: 42.5813, ppl: 47.9915
+0 active units
+update best loss
+TEST --- avg_loss: 42.5769, kl/H(z|x): 0.0001, mi: 0.0155, recon: 42.5768, nll: 42.5769, ppl: 47.9726
+epoch: 49, iter: 24500, avg_loss: 42.8160, kl/H(z|x): 0.0001, recon: 42.8159,time elapsed 992.94s
+epoch: 49, iter: 24550, avg_loss: 43.3220, kl/H(z|x): 0.0001, recon: 43.3219,time elapsed 994.02s
+epoch: 49, iter: 24600, avg_loss: 43.2801, kl/H(z|x): 0.0001, recon: 43.2799,time elapsed 995.12s
+epoch: 49, iter: 24650, avg_loss: 43.4456, kl/H(z|x): 0.0001, recon: 43.4455,time elapsed 996.21s
+epoch: 49, iter: 24700, avg_loss: 43.3676, kl/H(z|x): 0.0001, recon: 43.3675,time elapsed 997.28s
+epoch: 49, iter: 24750, avg_loss: 42.8345, kl/H(z|x): 0.0001, recon: 42.8343,time elapsed 998.40s
+epoch: 49, iter: 24800, avg_loss: 43.6380, kl/H(z|x): 0.0001, recon: 43.6378,time elapsed 999.48s
+epoch: 49, iter: 24850, avg_loss: 43.1312, kl/H(z|x): 0.0001, recon: 43.1311,time elapsed 1000.56s
+epoch: 49, iter: 24900, avg_loss: 43.2566, kl/H(z|x): 0.0001, recon: 43.2565,time elapsed 1001.64s
+epoch: 49, iter: 24950, avg_loss: 43.8295, kl/H(z|x): 0.0001, recon: 43.8293,time elapsed 1002.77s
+kl weight 1.0000
+VAL --- avg_loss: 42.6173, kl/H(z|x): 0.0001, mi: 0.0351, recon: 42.6171, nll: 42.6173, ppl: 48.1488
+0 active units
+TEST --- avg_loss: 42.6258, kl/H(z|x): 0.0001, mi: -0.1673, recon: 42.6256, nll: 42.6258, ppl: 48.1860
+TEST --- avg_loss: 42.5708, kl/H(z|x): 0.0001, mi: 0.0651, recon: 42.5707, nll: 42.5708, ppl: 47.9457
+0 active units
+iw nll computing 00%
+iw nll computing 10%
+iw nll computing 20%
+iw nll computing 30%
+iw nll computing 40%
+iw nll computing 50%
+iw nll computing 60%
+iw nll computing 70%
+iw nll computing 80%
+iw nll computing 90%
+iw nll: 42.5245, iw ppl: 47.7445
diff --git a/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt
new file mode 100644
index 0000000..6b00894
Binary files /dev/null and b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0/model.pt differ
diff --git a/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt
new file mode 100644
index 0000000..f53c3c4
--- /dev/null
+++ b/models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt
@@ -0,0 +1,5 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=0.0, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=0.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/synthetic/syntheticKL1.00_dr0.00_nz32_0_0_783435_lr1.0_betaF_5/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
diff --git a/models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
new file mode 100644
index 0000000..6794a71
--- /dev/null
+++ b/models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
@@ -0,0 +1,5 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=0.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/synthetic/synthetic_KL0.00_warm100_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
diff --git a/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
new file mode 100644
index 0000000..8cef7ec
--- /dev/null
+++ b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
@@ -0,0 +1,7 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=0.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=10)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 91.9735, kl/H(z|x): 15.9433, mi: 0.0662, recon: 76.0302,au 0, time elapsed 4.20s
+epoch: 0, iter: 50, avg_loss: 81.7852, kl/H(z|x): 17.6312, mi: 3.5894, recon: 64.1539,au 32, time elapsed 9.39s
diff --git a/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt
new file mode 100644
index 0000000..8f8f059
Binary files /dev/null and b/models/synthetic/synthetic_KL0.00_warm10_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt differ
diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt
new file mode 100644
index 0000000..f325a1c
--- /dev/null
+++ b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt
@@ -0,0 +1,162 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, reset_dec=False, save_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 91.9838, kl/H(z|x): 15.9542, mi: 0.0727, recon: 76.0295,au 0, time elapsed 3.89s
+gamma Parameter containing:
+tensor([0.6361, 0.8692, 1.0447, 1.0875, 0.8834, 1.1141, 0.9036, 0.9981, 0.7376,
+ 1.0701, 1.0894, 0.8052, 1.0607, 0.9636, 0.8960, 1.0999, 1.2594, 1.0827,
+ 0.9545, 1.0415, 0.8483, 0.7461, 0.9029, 1.0743, 0.9385, 1.0999, 0.7905,
+ 1.1520, 1.0535, 1.0816, 1.3140, 1.0453], requires_grad=True)
+train loc mean tensor([-0.0733, -0.0924, -0.0358, -0.0838, -0.0454, 0.0331, -0.0310, 0.0068,
+ -0.0078, 0.0136, -0.0392, -0.0524, 0.0494, 0.0408, -0.0388, -0.0561,
+ -0.0274, -0.0274, 0.0320, -0.1015, 0.0355, 0.0182, 0.0461, 0.0107,
+ 0.0683, -0.0382, -0.0352, -0.0434, 0.0076, 0.0446, -0.0125, -0.0292])
+train scale std tensor([7.7909e-05, 7.0655e-05, 8.6690e-05, 3.9981e-05, 5.8483e-05, 5.1343e-05,
+ 5.1193e-05, 7.1845e-05, 4.4821e-05, 4.0294e-05, 2.5077e-05, 5.2438e-05,
+ 3.7955e-05, 6.1363e-05, 5.4091e-05, 4.8194e-05, 7.3150e-05, 8.1629e-05,
+ 4.1456e-05, 6.3313e-05, 8.4339e-05, 2.3192e-05, 9.8620e-05, 2.6473e-05,
+ 6.0505e-05, 4.4957e-05, 3.8171e-05, 7.3833e-05, 3.8787e-05, 5.0260e-05,
+ 3.9497e-05, 5.0977e-05])
+train scale mean tensor([1.0283, 1.0284, 1.0290, 1.0281, 1.0284, 1.0283, 1.0283, 1.0281, 1.0282,
+ 1.0285, 1.0286, 1.0284, 1.0279, 1.0284, 1.0276, 1.0282, 1.0295, 1.0288,
+ 1.0292, 1.0284, 1.0287, 1.0285, 1.0282, 1.0289, 1.0293, 1.0279, 1.0287,
+ 1.0282, 1.0281, 1.0293, 1.0288, 1.0287])
+epoch: 0, iter: 50, avg_loss: 68.9229, kl/H(z|x): 1.8562, mi: -0.0067, recon: 67.0667,au 0, time elapsed 8.64s
+epoch: 0, iter: 100, avg_loss: 61.1798, kl/H(z|x): 0.2321, mi: 0.0196, recon: 60.9478,au 0, time elapsed 13.72s
+epoch: 0, iter: 150, avg_loss: 60.0701, kl/H(z|x): 0.3995, mi: -0.1306, recon: 59.6706,au 0, time elapsed 19.49s
+epoch: 0, iter: 200, avg_loss: 57.2679, kl/H(z|x): 0.1905, mi: -0.0269, recon: 57.0774,au 0, time elapsed 24.24s
+epoch: 0, iter: 250, avg_loss: 54.3758, kl/H(z|x): 0.1566, mi: -0.0273, recon: 54.2192,au 0, time elapsed 28.84s
+epoch: 0, iter: 300, avg_loss: 52.8017, kl/H(z|x): 0.1288, mi: -0.0506, recon: 52.6729,au 0, time elapsed 33.81s
+epoch: 0, iter: 350, avg_loss: 51.9674, kl/H(z|x): 0.1882, mi: 0.0187, recon: 51.7792,au 0, time elapsed 38.53s
+epoch: 0, iter: 400, avg_loss: 51.5230, kl/H(z|x): 0.1526, mi: -0.0481, recon: 51.3704,au 0, time elapsed 43.18s
+epoch: 0, iter: 450, avg_loss: 51.2951, kl/H(z|x): 0.1565, mi: 0.0256, recon: 51.1385,au 0, time elapsed 47.97s
+kl weight 1.0000
+VAL --- avg_loss: 48.1727, kl/H(z|x): 0.1342, mi: -0.2018, recon: 48.0385, nll: 48.1727, ppl: 79.7848
+0 active units
+update best loss
+TEST --- avg_loss: 48.2219, kl/H(z|x): 0.1342, mi: -0.0120, recon: 48.0877, nll: 48.2219, ppl: 80.1426
+epoch: 1, iter: 500, avg_loss: 47.6891, kl/H(z|x): 0.1346, recon: 47.5546,time elapsed 62.33s
+epoch: 1, iter: 550, avg_loss: 50.0957, kl/H(z|x): 0.1378, recon: 49.9579,time elapsed 63.77s
+epoch: 1, iter: 600, avg_loss: 50.4986, kl/H(z|x): 0.1395, recon: 50.3591,time elapsed 65.03s
+epoch: 1, iter: 650, avg_loss: 49.5509, kl/H(z|x): 0.1456, recon: 49.4053,time elapsed 66.35s
+epoch: 1, iter: 700, avg_loss: 49.5276, kl/H(z|x): 0.2423, recon: 49.2853,time elapsed 67.58s
+epoch: 1, iter: 750, avg_loss: 49.0258, kl/H(z|x): 0.1555, recon: 48.8703,time elapsed 69.08s
+epoch: 1, iter: 800, avg_loss: 48.6905, kl/H(z|x): 0.1673, recon: 48.5232,time elapsed 70.86s
+epoch: 1, iter: 850, avg_loss: 48.3193, kl/H(z|x): 0.1522, recon: 48.1671,time elapsed 72.32s
+epoch: 1, iter: 900, avg_loss: 48.5319, kl/H(z|x): 0.1697, recon: 48.3622,time elapsed 73.62s
+epoch: 1, iter: 950, avg_loss: 47.5075, kl/H(z|x): 0.1012, recon: 47.4063,time elapsed 74.94s
+kl weight 1.0000
+VAL --- avg_loss: 45.2795, kl/H(z|x): 0.0738, mi: 0.0029, recon: 45.2057, nll: 45.2795, ppl: 61.3332
+0 active units
+update best loss
+TEST --- avg_loss: 45.2758, kl/H(z|x): 0.0738, mi: -0.0522, recon: 45.2020, nll: 45.2758, ppl: 61.3126
+epoch: 2, iter: 1000, avg_loss: 45.3208, kl/H(z|x): 0.0742, recon: 45.2465,time elapsed 87.48s
+epoch: 2, iter: 1050, avg_loss: 47.0271, kl/H(z|x): 0.0944, recon: 46.9327,time elapsed 89.00s
+epoch: 2, iter: 1100, avg_loss: 47.5806, kl/H(z|x): 0.2097, recon: 47.3709,time elapsed 91.38s
+epoch: 2, iter: 1150, avg_loss: 47.6823, kl/H(z|x): 0.1700, recon: 47.5123,time elapsed 92.75s
+epoch: 2, iter: 1200, avg_loss: 47.0332, kl/H(z|x): 0.1020, recon: 46.9313,time elapsed 93.91s
+epoch: 2, iter: 1250, avg_loss: 47.2249, kl/H(z|x): 0.0633, recon: 47.1616,time elapsed 95.02s
+epoch: 2, iter: 1300, avg_loss: 46.9512, kl/H(z|x): 0.1286, recon: 46.8226,time elapsed 96.24s
+epoch: 2, iter: 1350, avg_loss: 46.9885, kl/H(z|x): 0.1035, recon: 46.8850,time elapsed 97.37s
+epoch: 2, iter: 1400, avg_loss: 46.3206, kl/H(z|x): 0.1372, recon: 46.1834,time elapsed 98.49s
+epoch: 2, iter: 1450, avg_loss: 47.1521, kl/H(z|x): 0.1887, recon: 46.9634,time elapsed 99.60s
+kl weight 1.0000
+VAL --- avg_loss: 44.4193, kl/H(z|x): 0.0622, mi: 0.0397, recon: 44.3571, nll: 44.4193, ppl: 56.7194
+0 active units
+update best loss
+TEST --- avg_loss: 44.4166, kl/H(z|x): 0.0622, mi: -0.0181, recon: 44.3544, nll: 44.4166, ppl: 56.7056
+epoch: 3, iter: 1500, avg_loss: 45.3653, kl/H(z|x): 0.0619, recon: 45.3034,time elapsed 109.59s
+epoch: 3, iter: 1550, avg_loss: 46.6098, kl/H(z|x): 0.0684, recon: 46.5414,time elapsed 110.68s
+epoch: 3, iter: 1600, avg_loss: 46.6580, kl/H(z|x): 0.0872, recon: 46.5708,time elapsed 111.78s
+epoch: 3, iter: 1650, avg_loss: 46.1828, kl/H(z|x): 0.0844, recon: 46.0984,time elapsed 113.02s
+epoch: 3, iter: 1700, avg_loss: 46.3273, kl/H(z|x): 0.0641, recon: 46.2631,time elapsed 114.24s
+epoch: 3, iter: 1750, avg_loss: 46.4725, kl/H(z|x): 0.0632, recon: 46.4093,time elapsed 115.38s
+epoch: 3, iter: 1800, avg_loss: 46.6433, kl/H(z|x): 0.1000, recon: 46.5433,time elapsed 116.50s
+epoch: 3, iter: 1850, avg_loss: 46.1040, kl/H(z|x): 0.0918, recon: 46.0121,time elapsed 117.60s
+gamma Parameter containing:
+tensor([0.7547, 0.8856, 1.1843, 1.0126, 1.0597, 0.9106, 0.9730, 1.1779, 0.7441,
+ 1.1986, 0.9274, 0.9449, 1.0565, 0.5328, 0.8193, 1.2728, 1.2570, 0.8743,
+ 0.9534, 0.7917, 0.8276, 0.8855, 1.0832, 1.0849, 1.1097, 1.1932, 0.7251,
+ 1.1920, 1.1917, 0.9867, 0.9388, 0.9593], requires_grad=True)
+train loc mean tensor([ 0.0197, -0.0169, -0.0196, -0.0653, -0.0198, 0.0277, -0.0709, 0.0039,
+ 0.0091, -0.0372, -0.0042, -0.0662, 0.0050, 0.0271, 0.0054, 0.0648,
+ 0.0158, -0.1142, 0.0130, -0.0441, -0.0233, 0.0194, 0.0196, -0.0130,
+ 0.0155, 0.0078, 0.0079, 0.0202, 0.0312, -0.0052, -0.0383, -0.0199],
+ grad_fn=)
+train scale std tensor([1.0685e-05, 8.4577e-06, 1.6077e-05, 1.5295e-05, 1.4061e-05, 1.4010e-05,
+ 2.2436e-05, 1.0087e-05, 5.1636e-06, 7.0330e-06, 5.6599e-06, 6.8874e-06,
+ 9.7419e-06, 1.8101e-05, 1.2978e-05, 1.5947e-05, 1.4863e-05, 5.4070e-06,
+ 3.5708e-06, 1.6960e-05, 1.7095e-05, 1.3345e-05, 1.7271e-05, 1.4926e-05,
+ 8.5931e-06, 1.1095e-05, 5.7805e-06, 1.9978e-05, 1.6887e-05, 5.9923e-06,
+ 8.9861e-06, 1.3255e-05], grad_fn=)
+train scale mean tensor([1.0287, 1.0287, 1.0286, 1.0287, 1.0287, 1.0287, 1.0285, 1.0287, 1.0287,
+ 1.0287, 1.0288, 1.0288, 1.0287, 1.0286, 1.0287, 1.0286, 1.0287, 1.0288,
+ 1.0289, 1.0287, 1.0286, 1.0287, 1.0286, 1.0286, 1.0289, 1.0288, 1.0288,
+ 1.0285, 1.0286, 1.0288, 1.0288, 1.0285], grad_fn=)
+epoch: 3, iter: 1900, avg_loss: 46.0540, kl/H(z|x): 0.0592, recon: 45.9948,time elapsed 118.72s
+epoch: 3, iter: 1950, avg_loss: 45.6893, kl/H(z|x): 0.0609, recon: 45.6284,time elapsed 119.82s
+kl weight 1.0000
+VAL --- avg_loss: 44.0410, kl/H(z|x): 0.0468, mi: 0.0664, recon: 43.9941, nll: 44.0410, ppl: 54.8018
+0 active units
+update best loss
+gamma Parameter containing:
+tensor([0.7531, 0.8573, 1.1776, 1.0084, 1.0749, 0.9298, 0.9545, 1.1919, 0.7384,
+ 1.2042, 0.9418, 0.9320, 1.0016, 0.5333, 0.8417, 1.2837, 1.2379, 0.8844,
+ 0.9579, 0.7987, 0.8359, 0.8981, 1.0969, 1.0870, 1.1095, 1.2033, 0.7139,
+ 1.1977, 1.1889, 0.9906, 0.9035, 0.9766], requires_grad=True)
+train loc mean tensor([ 0.0389, -0.0135, -0.0668, -0.0289, -0.0586, 0.0438, 0.0223, 0.0034,
+ 0.0052, 0.0265, 0.0018, 0.0417, -0.0163, 0.0010, 0.0236, -0.0516,
+ -0.0088, -0.0191, -0.0704, -0.0007, 0.0522, 0.0405, 0.0495, -0.0328,
+ -0.0161, -0.0143, 0.0084, -0.0512, 0.0116, 0.0380, -0.0292, -0.0483])
+train scale std tensor([8.3092e-06, 6.8119e-06, 1.7887e-05, 1.3828e-05, 1.8606e-05, 1.7448e-05,
+ 2.2577e-05, 3.3867e-06, 6.2293e-06, 4.3768e-06, 6.3535e-06, 7.3117e-06,
+ 4.1512e-06, 1.5505e-05, 8.7723e-06, 9.0478e-06, 2.0474e-05, 9.4471e-06,
+ 4.0762e-06, 6.4814e-06, 1.3078e-05, 1.9357e-05, 6.2588e-06, 1.1728e-05,
+ 7.7367e-06, 9.2317e-06, 7.4100e-06, 1.7443e-05, 1.0280e-05, 2.9517e-06,
+ 1.0600e-05, 9.2339e-06])
+train scale mean tensor([1.0288, 1.0287, 1.0287, 1.0287, 1.0287, 1.0286, 1.0286, 1.0288, 1.0287,
+ 1.0288, 1.0287, 1.0288, 1.0287, 1.0287, 1.0288, 1.0287, 1.0286, 1.0288,
+ 1.0289, 1.0288, 1.0287, 1.0287, 1.0288, 1.0287, 1.0288, 1.0288, 1.0288,
+ 1.0286, 1.0288, 1.0288, 1.0289, 1.0287])
+TEST --- avg_loss: 44.0432, kl/H(z|x): 0.0468, mi: -0.0153, recon: 43.9964, nll: 44.0432, ppl: 54.8129
+epoch: 4, iter: 2000, avg_loss: 43.9364, kl/H(z|x): 0.0473, recon: 43.8891,time elapsed 129.21s
+epoch: 4, iter: 2050, avg_loss: 46.1033, kl/H(z|x): 0.1925, recon: 45.9108,time elapsed 130.29s
+epoch: 4, iter: 2100, avg_loss: 46.2411, kl/H(z|x): 0.0922, recon: 46.1489,time elapsed 131.37s
+epoch: 4, iter: 2150, avg_loss: 46.2292, kl/H(z|x): 0.0933, recon: 46.1359,time elapsed 132.48s
+epoch: 4, iter: 2200, avg_loss: 46.3031, kl/H(z|x): 0.1519, recon: 46.1512,time elapsed 133.61s
+epoch: 4, iter: 2250, avg_loss: 46.1080, kl/H(z|x): 0.1128, recon: 45.9952,time elapsed 134.71s
+epoch: 4, iter: 2300, avg_loss: 46.2069, kl/H(z|x): 0.1018, recon: 46.1051,time elapsed 135.80s
+epoch: 4, iter: 2350, avg_loss: 46.0232, kl/H(z|x): 0.1626, recon: 45.8606,time elapsed 136.88s
+epoch: 4, iter: 2400, avg_loss: 46.4018, kl/H(z|x): 0.1171, recon: 46.2848,time elapsed 138.01s
+epoch: 4, iter: 2450, avg_loss: 45.8789, kl/H(z|x): 0.0603, recon: 45.8185,time elapsed 139.11s
+kl weight 1.0000
+VAL --- avg_loss: 43.9069, kl/H(z|x): 0.0898, mi: 0.0050, recon: 43.8171, nll: 43.9069, ppl: 54.1382
+0 active units
+update best loss
+gamma Parameter containing:
+tensor([0.7509, 0.8669, 1.1856, 0.9956, 1.0628, 0.9358, 0.9456, 1.1958, 0.7394,
+ 1.2110, 0.9154, 0.9326, 1.0014, 0.5382, 0.8347, 1.2811, 1.2412, 0.8911,
+ 0.9606, 0.8028, 0.8493, 0.8925, 1.1058, 1.0843, 1.1013, 1.2081, 0.7034,
+ 1.2092, 1.1926, 0.9834, 0.9101, 0.9705], requires_grad=True)
+train loc mean tensor([-0.0450, -0.0759, -0.0783, 0.0150, 0.1056, 0.1129, -0.0218, 0.0801,
+ -0.0061, 0.0265, 0.1380, -0.0496, -0.0869, -0.0271, -0.0230, -0.0508,
+ 0.0639, -0.0268, 0.0608, 0.0407, -0.0362, -0.0448, -0.0678, 0.0024,
+ -0.0354, 0.0236, 0.0404, -0.1080, 0.0892, 0.0413, -0.0047, -0.0542])
+train scale std tensor([1.6924e-05, 1.6882e-05, 1.0246e-05, 1.2540e-05, 4.4682e-06, 1.0566e-05,
+ 1.2307e-05, 1.0101e-05, 1.4016e-05, 2.1958e-05, 2.5869e-05, 1.4119e-05,
+ 1.4662e-05, 1.0661e-05, 1.0339e-05, 9.7715e-06, 1.3193e-05, 1.2174e-05,
+ 5.5699e-06, 2.0073e-05, 1.2144e-05, 1.1599e-05, 1.1315e-05, 1.2591e-05,
+ 2.3388e-05, 1.0639e-05, 1.2606e-05, 1.6476e-05, 1.3287e-05, 1.2702e-05,
+ 1.0028e-05, 1.5114e-05])
+train scale mean tensor([1.0286, 1.0285, 1.0287, 1.0287, 1.0288, 1.0287, 1.0287, 1.0286, 1.0285,
+ 1.0286, 1.0285, 1.0286, 1.0286, 1.0288, 1.0286, 1.0288, 1.0285, 1.0287,
+ 1.0289, 1.0286, 1.0287, 1.0288, 1.0287, 1.0287, 1.0284, 1.0286, 1.0288,
+ 1.0286, 1.0285, 1.0287, 1.0286, 1.0285])
+TEST --- avg_loss: 43.9080, kl/H(z|x): 0.0898, mi: 0.0293, recon: 43.8182, nll: 43.9080, ppl: 54.1436
+epoch: 5, iter: 2500, avg_loss: 46.1478, kl/H(z|x): 0.0773, recon: 46.0705,time elapsed 148.53s
+epoch: 5, iter: 2550, avg_loss: 45.7327, kl/H(z|x): 0.0603, recon: 45.6724,time elapsed 149.88s
+epoch: 5, iter: 2600, avg_loss: 45.5843, kl/H(z|x): 0.0852, recon: 45.4991,time elapsed 151.01s
+epoch: 5, iter: 2650, avg_loss: 45.9211, kl/H(z|x): 0.0895, recon: 45.8317,time elapsed 152.14s
+epoch: 5, iter: 2700, avg_loss: 45.5345, kl/H(z|x): 0.0683, recon: 45.4662,time elapsed 153.30s
diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt
new file mode 100644
index 0000000..3fe6613
Binary files /dev/null and b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_0_0_783435_lr1.0/model.pt differ
diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt
new file mode 100644
index 0000000..20aafca
--- /dev/null
+++ b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt
@@ -0,0 +1,77 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.2, reset_dec=False, save_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 98.1471, kl/H(z|x): 22.1174, mi: 0.0398, recon: 76.0297,au 0, time elapsed 3.80s
+epoch: 0, iter: 50, avg_loss: 75.2584, kl/H(z|x): 8.0159, mi: -0.0773, recon: 67.2426,au 0, time elapsed 8.56s
+epoch: 0, iter: 100, avg_loss: 67.1634, kl/H(z|x): 6.2421, mi: -0.0161, recon: 60.9213,au 0, time elapsed 13.36s
+epoch: 0, iter: 150, avg_loss: 65.9610, kl/H(z|x): 6.4145, mi: 0.6229, recon: 59.5465,au 26, time elapsed 20.04s
+epoch: 0, iter: 200, avg_loss: 64.0590, kl/H(z|x): 6.3462, mi: 0.0654, recon: 57.7129,au 7, time elapsed 25.26s
+epoch: 0, iter: 250, avg_loss: 60.5662, kl/H(z|x): 6.2482, mi: -0.0251, recon: 54.3180,au 0, time elapsed 30.56s
+epoch: 0, iter: 300, avg_loss: 58.9862, kl/H(z|x): 6.2296, mi: -0.0323, recon: 52.7566,au 0, time elapsed 35.99s
+epoch: 0, iter: 350, avg_loss: 57.9861, kl/H(z|x): 6.2317, mi: 0.0645, recon: 51.7544,au 0, time elapsed 41.22s
+epoch: 0, iter: 400, avg_loss: 57.7159, kl/H(z|x): 6.2682, mi: -0.0005, recon: 51.4477,au 0, time elapsed 46.25s
+epoch: 0, iter: 450, avg_loss: 56.8888, kl/H(z|x): 6.1970, mi: 0.0484, recon: 50.6919,au 0, time elapsed 53.38s
+kl weight 1.0000
+VAL --- avg_loss: 69.2053, kl/H(z|x): 18.6423, mi: -0.1538, recon: 50.5630, nll: 69.2053, ppl: 539.9051
+0 active units
+update best loss
+TEST --- avg_loss: 69.2140, kl/H(z|x): 18.6423, mi: -0.0284, recon: 50.5717, nll: 69.2140, ppl: 540.3313
+epoch: 1, iter: 500, avg_loss: 55.1085, kl/H(z|x): 6.3617, recon: 48.7468,time elapsed 62.84s
+epoch: 1, iter: 550, avg_loss: 55.3878, kl/H(z|x): 6.1804, recon: 49.2074,time elapsed 64.00s
+epoch: 1, iter: 600, avg_loss: 55.1184, kl/H(z|x): 6.2019, recon: 48.9165,time elapsed 65.20s
+epoch: 1, iter: 650, avg_loss: 54.3091, kl/H(z|x): 6.2124, recon: 48.0967,time elapsed 66.34s
+epoch: 1, iter: 700, avg_loss: 54.2357, kl/H(z|x): 6.5016, recon: 47.7341,time elapsed 67.88s
+epoch: 1, iter: 750, avg_loss: 53.9039, kl/H(z|x): 6.1897, recon: 47.7142,time elapsed 69.46s
+epoch: 1, iter: 800, avg_loss: 53.8167, kl/H(z|x): 6.2032, recon: 47.6135,time elapsed 70.90s
+epoch: 1, iter: 850, avg_loss: 53.5023, kl/H(z|x): 6.2023, recon: 47.2999,time elapsed 73.23s
+epoch: 1, iter: 900, avg_loss: 53.8856, kl/H(z|x): 6.2500, recon: 47.6356,time elapsed 75.30s
+epoch: 1, iter: 950, avg_loss: 53.3251, kl/H(z|x): 6.2294, recon: 47.0958,time elapsed 76.53s
+kl weight 1.0000
+VAL --- avg_loss: 49.8829, kl/H(z|x): 3.8641, mi: 0.0947, recon: 46.0188, nll: 49.8829, ppl: 93.2058
+0 active units
+update best loss
+TEST --- avg_loss: 49.8833, kl/H(z|x): 3.8641, mi: -0.0544, recon: 46.0192, nll: 49.8833, ppl: 93.2088
+epoch: 2, iter: 1000, avg_loss: 53.2816, kl/H(z|x): 5.3442, recon: 47.9374,time elapsed 91.88s
+epoch: 2, iter: 1050, avg_loss: 52.7176, kl/H(z|x): 6.2500, recon: 46.4677,time elapsed 93.11s
+epoch: 2, iter: 1100, avg_loss: 53.0183, kl/H(z|x): 6.2678, recon: 46.7505,time elapsed 94.73s
+epoch: 2, iter: 1150, avg_loss: 53.2235, kl/H(z|x): 6.2141, recon: 47.0094,time elapsed 95.95s
+epoch: 2, iter: 1200, avg_loss: 52.6633, kl/H(z|x): 6.0780, recon: 46.5854,time elapsed 99.24s
+epoch: 2, iter: 1250, avg_loss: 52.9677, kl/H(z|x): 6.1585, recon: 46.8092,time elapsed 100.49s
+epoch: 2, iter: 1300, avg_loss: 52.5459, kl/H(z|x): 6.1299, recon: 46.4160,time elapsed 101.89s
+epoch: 2, iter: 1350, avg_loss: 52.6617, kl/H(z|x): 6.0985, recon: 46.5632,time elapsed 103.14s
+epoch: 2, iter: 1400, avg_loss: 52.0059, kl/H(z|x): 6.0843, recon: 45.9216,time elapsed 104.36s
+epoch: 2, iter: 1450, avg_loss: 52.3689, kl/H(z|x): 6.1425, recon: 46.2264,time elapsed 105.60s
+kl weight 1.0000
+VAL --- avg_loss: 45.3509, kl/H(z|x): 1.3035, mi: 0.0306, recon: 44.0474, nll: 45.3509, ppl: 61.7324
+0 active units
+update best loss
+TEST --- avg_loss: 45.3404, kl/H(z|x): 1.3035, mi: -0.0230, recon: 44.0370, nll: 45.3404, ppl: 61.6738
+epoch: 3, iter: 1500, avg_loss: 49.9944, kl/H(z|x): 5.5341, recon: 44.4602,time elapsed 117.11s
+epoch: 3, iter: 1550, avg_loss: 52.2121, kl/H(z|x): 6.0559, recon: 46.1563,time elapsed 119.99s
+epoch: 3, iter: 1600, avg_loss: 52.0304, kl/H(z|x): 6.0492, recon: 45.9812,time elapsed 121.21s
+epoch: 3, iter: 1650, avg_loss: 52.0216, kl/H(z|x): 6.1220, recon: 45.8996,time elapsed 122.39s
+epoch: 3, iter: 1700, avg_loss: 52.1720, kl/H(z|x): 6.1192, recon: 46.0528,time elapsed 123.50s
+epoch: 3, iter: 1750, avg_loss: 52.4309, kl/H(z|x): 6.0886, recon: 46.3423,time elapsed 124.60s
+epoch: 3, iter: 1800, avg_loss: 52.6185, kl/H(z|x): 6.3136, recon: 46.3049,time elapsed 125.72s
+epoch: 3, iter: 1850, avg_loss: 51.9571, kl/H(z|x): 6.1577, recon: 45.7994,time elapsed 126.91s
+epoch: 3, iter: 1900, avg_loss: 51.9346, kl/H(z|x): 6.1612, recon: 45.7734,time elapsed 128.86s
+epoch: 3, iter: 1950, avg_loss: 51.6526, kl/H(z|x): 6.1017, recon: 45.5510,time elapsed 131.37s
+kl weight 1.0000
+VAL --- avg_loss: 44.7754, kl/H(z|x): 0.8472, mi: 0.0321, recon: 43.9282, nll: 44.7754, ppl: 58.5859
+0 active units
+update best loss
+TEST --- avg_loss: 44.7674, kl/H(z|x): 0.8472, mi: -0.0025, recon: 43.9202, nll: 44.7674, ppl: 58.5432
+epoch: 4, iter: 2000, avg_loss: 50.4556, kl/H(z|x): 6.1846, recon: 44.2711,time elapsed 142.19s
+epoch: 4, iter: 2050, avg_loss: 51.7029, kl/H(z|x): 6.2001, recon: 45.5028,time elapsed 143.48s
+epoch: 4, iter: 2100, avg_loss: 52.0093, kl/H(z|x): 6.0755, recon: 45.9338,time elapsed 144.65s
+epoch: 4, iter: 2150, avg_loss: 51.8900, kl/H(z|x): 6.0952, recon: 45.7948,time elapsed 145.75s
+epoch: 4, iter: 2200, avg_loss: 51.6964, kl/H(z|x): 6.0536, recon: 45.6428,time elapsed 146.84s
+epoch: 4, iter: 2250, avg_loss: 51.6489, kl/H(z|x): 6.1424, recon: 45.5065,time elapsed 148.00s
+epoch: 4, iter: 2300, avg_loss: 51.8192, kl/H(z|x): 6.0458, recon: 45.7734,time elapsed 149.10s
+epoch: 4, iter: 2350, avg_loss: 51.3648, kl/H(z|x): 5.9993, recon: 45.3655,time elapsed 150.19s
+epoch: 4, iter: 2400, avg_loss: 52.3250, kl/H(z|x): 6.2564, recon: 46.0686,time elapsed 151.32s
+epoch: 4, iter: 2450, avg_loss: 51.6578, kl/H(z|x): 6.0313, recon: 45.6266,time elapsed 152.43s
+kl weight 1.0000
+VAL --- avg_loss: 44.6217, kl/H(z|x): 0.7541, mi: 0.1146, recon: 43.8676, nll: 44.6217, ppl: 57.7731
diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt
new file mode 100644
index 0000000..6abdeed
Binary files /dev/null and b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.20_0_0_783435_lr1.0/model.pt differ
diff --git a/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
new file mode 100644
index 0000000..af8a5a9
--- /dev/null
+++ b/models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt
@@ -0,0 +1,10 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', decode_from='', decode_input='', decoding_strategy='greedy', delta_rate=1, device='cpu', enc_nh=50, enc_type='lstm', epochs=50, eval=False, gamma=1.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.5, reset_dec=False, save_path='models/synthetic/synthetic_KL1.00_gamma1.00_dr1.00_nz32_drop0.50_0_0_783435_lr1.0/model.pt', seed=783435, target_kl=-1, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 109.8254, kl/H(z|x): 33.7955, mi: 0.0425, recon: 76.0299,au 0, time elapsed 3.76s
+epoch: 0, iter: 50, avg_loss: 83.8892, kl/H(z|x): 16.9853, mi: -0.0774, recon: 66.9040,au 0, time elapsed 8.47s
+epoch: 0, iter: 100, avg_loss: 75.9278, kl/H(z|x): 15.2441, mi: -0.0223, recon: 60.6837,au 0, time elapsed 13.28s
+epoch: 0, iter: 150, avg_loss: 74.6431, kl/H(z|x): 15.2945, mi: 0.5371, recon: 59.3485,au 20, time elapsed 17.94s
+epoch: 0, iter: 200, avg_loss: 72.5163, kl/H(z|x): 15.3962, mi: -0.0651, recon: 57.1201,au 0, time elapsed 22.74s
diff --git a/models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
new file mode 100644
index 0000000..a5948a4
--- /dev/null
+++ b/models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
@@ -0,0 +1,9 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 89.9236, kl/H(z|x): 13.8949, mi: 0.0091, recon: 76.0287,au 0, time elapsed 3.66s
+epoch: 0, iter: 50, avg_loss: 75.8018, kl/H(z|x): 9.7718, mi: 0.0997, recon: 66.0301,au 0, time elapsed 8.55s
+epoch: 0, iter: 100, avg_loss: 69.6457, kl/H(z|x): 9.1719, mi: -0.0288, recon: 60.4737,au 0, time elapsed 13.30s
+epoch: 0, iter: 150, avg_loss: 68.5756, kl/H(z|x): 9.0761, mi: -0.0296, recon: 59.4994,au 0, time elapsed 18.10s
diff --git a/models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
new file mode 100644
index 0000000..4f8fdca
--- /dev/null
+++ b/models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
@@ -0,0 +1,9 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.0, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_fb1_tr0.15_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.15, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 89.9236, kl/H(z|x): 13.8949, mi: 0.0091, recon: 76.0287,au 0, time elapsed 5.64s
+epoch: 0, iter: 50, avg_loss: 76.0595, kl/H(z|x): 9.8928, mi: 0.0997, recon: 66.1668,au 0, time elapsed 10.54s
+epoch: 0, iter: 100, avg_loss: 69.7483, kl/H(z|x): 9.2888, mi: -0.0287, recon: 60.4595,au 0, time elapsed 15.83s
+epoch: 0, iter: 150, avg_loss: 68.6088, kl/H(z|x): 9.1824, mi: -0.0296, recon: 59.4264,au 0, time elapsed 20.88s
diff --git a/models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt
new file mode 100644
index 0000000..6fb79c7
--- /dev/null
+++ b/models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt
@@ -0,0 +1,11 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=0.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.0, gamma_train=False, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0, save_path='models/synthetic/synthetic_fd2_fw2_dr0.00_nz32_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 84.0465, kl/H(z|x): 8.0181, mi: -0.0320, recon: 76.0285,au 0, time elapsed 3.85s
+epoch: 0, iter: 50, avg_loss: 67.9761, kl/H(z|x): 1.1939, mi: 0.0944, recon: 66.7822,au 0, time elapsed 8.59s
+epoch: 0, iter: 100, avg_loss: 60.7799, kl/H(z|x): 0.1675, mi: -0.0444, recon: 60.6124,au 0, time elapsed 13.57s
+epoch: 0, iter: 150, avg_loss: 59.6752, kl/H(z|x): 0.1535, mi: -0.0352, recon: 59.5216,au 0, time elapsed 18.29s
+epoch: 0, iter: 200, avg_loss: 56.9132, kl/H(z|x): 0.1076, mi: 0.0690, recon: 56.8057,au 0, time elapsed 23.04s
+epoch: 0, iter: 250, avg_loss: 54.5933, kl/H(z|x): 0.1175, mi: 0.0078, recon: 54.4758,au 0, time elapsed 27.81s
diff --git a/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
new file mode 100644
index 0000000..011c7b5
--- /dev/null
+++ b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
@@ -0,0 +1,7 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=1, flow_depth=2, flow_width=2, gamma=0.8, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 90.0487, kl/H(z|x): 14.0200, mi: -0.0322, recon: 76.0287,au 0, time elapsed 3.93s
+epoch: 0, iter: 50, avg_loss: 81.0231, kl/H(z|x): 14.3386, mi: 0.7551, recon: 66.6845,au 12, time elapsed 8.63s
diff --git a/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt
new file mode 100644
index 0000000..4d782c9
Binary files /dev/null and b/models/synthetic/synthetic_gamma0.80_fb1_tr0.00_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt differ
diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
new file mode 100644
index 0000000..4ed17a2
--- /dev/null
+++ b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
@@ -0,0 +1,24 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=0.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.8, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 84.2557, kl/H(z|x): 8.2271, mi: -0.0313, recon: 76.0286,au 0, time elapsed 3.82s
+epoch: 0, iter: 50, avg_loss: 70.6074, kl/H(z|x): 4.2926, mi: 0.7366, recon: 66.3148,au 22, time elapsed 9.52s
+epoch: 0, iter: 100, avg_loss: 62.1741, kl/H(z|x): 2.6194, mi: 0.6129, recon: 59.5546,au 23, time elapsed 14.88s
+epoch: 0, iter: 150, avg_loss: 60.9270, kl/H(z|x): 2.6617, mi: 0.6679, recon: 58.2653,au 23, time elapsed 20.14s
+epoch: 0, iter: 200, avg_loss: 58.7194, kl/H(z|x): 2.2669, mi: 0.6364, recon: 56.4524,au 24, time elapsed 25.00s
+epoch: 0, iter: 250, avg_loss: 55.7990, kl/H(z|x): 1.8970, mi: 0.4772, recon: 53.9020,au 20, time elapsed 29.83s
+epoch: 0, iter: 300, avg_loss: 54.3986, kl/H(z|x): 1.9483, mi: 0.3116, recon: 52.4503,au 19, time elapsed 34.72s
+epoch: 0, iter: 350, avg_loss: 53.3452, kl/H(z|x): 1.8530, mi: 0.5277, recon: 51.4922,au 19, time elapsed 39.95s
+epoch: 0, iter: 400, avg_loss: 53.2150, kl/H(z|x): 1.7295, mi: 0.2439, recon: 51.4855,au 16, time elapsed 45.09s
+epoch: 0, iter: 450, avg_loss: 52.9093, kl/H(z|x): 1.8259, mi: 0.3234, recon: 51.0834,au 18, time elapsed 50.17s
+kl weight 1.0000
+VAL --- avg_loss: 48.8077, kl/H(z|x): 1.2442, mi: 0.2892, recon: 47.5353, nll: 48.7795, ppl: 84.3100
+20 active units
+update best loss
+TEST --- avg_loss: 48.8295, kl/H(z|x): 1.2387, mi: 0.2400, recon: 47.5826, nll: 48.8212, ppl: 84.6304
+epoch: 1, iter: 500, avg_loss: 47.8222, kl/H(z|x): 1.2917, recon: 46.5305,time elapsed 60.22s
+epoch: 1, iter: 550, avg_loss: 51.4242, kl/H(z|x): 1.7375, recon: 49.6867,time elapsed 61.58s
+epoch: 1, iter: 600, avg_loss: 51.8762, kl/H(z|x): 1.6511, recon: 50.2251,time elapsed 64.43s
+epoch: 1, iter: 650, avg_loss: 50.9553, kl/H(z|x): 1.6560, recon: 49.2993,time elapsed 66.71s
diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt
new file mode 100644
index 0000000..a1cabc4
Binary files /dev/null and b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr0.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt differ
diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
new file mode 100644
index 0000000..a874cc7
--- /dev/null
+++ b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt
@@ -0,0 +1,28 @@
+Namespace(batch_size=32, cuda=False, dataset='synthetic', dec_dropout_in=0.5, dec_dropout_out=0.5, dec_nh=50, dec_type='lstm', delta_rate=1.0, device='cpu', drop_start=1.0, enc_nh=50, enc_type='lstm', epochs=50, eval=False, fb=0, flow_depth=2, flow_width=2, gamma=0.8, iw_nsamples=500, jobid=0, kl_start=1.0, label=True, load_path='', log_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/log.txt', lr=1.0, momentum=0, ni=50, nsamples=1, nz=32, nz_new=32, p_drop=0.3, save_path='models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt', seed=783435, target_kl=0.0, taskid=0, test_data='data/synthetic_data/synthetic_test.txt', test_nepoch=1, train_data='data/synthetic_data/synthetic_train.txt', val_data='data/synthetic_data/synthetic_test.txt', vocab_file='data/synthetic_data/vocab.txt', warm_up=100)
+data/synthetic_data/vocab.txt
+Train data: 16000 samples
+finish reading datasets, vocab size is 1004
+dropped sentences: 0
+epoch: 0, iter: 0, avg_loss: 90.0487, kl/H(z|x): 14.0200, mi: 0.0092, recon: 76.0288,au 0, time elapsed 3.77s
+epoch: 0, iter: 50, avg_loss: 80.6753, kl/H(z|x): 14.6578, mi: 0.8141, recon: 66.0176,au 18, time elapsed 8.55s
+epoch: 0, iter: 100, avg_loss: 72.3764, kl/H(z|x): 13.0901, mi: 1.0537, recon: 59.2863,au 19, time elapsed 13.41s
+epoch: 0, iter: 150, avg_loss: 71.4821, kl/H(z|x): 13.4491, mi: 0.6681, recon: 58.0330,au 19, time elapsed 18.23s
+epoch: 0, iter: 200, avg_loss: 68.6188, kl/H(z|x): 12.5864, mi: 0.6411, recon: 56.0324,au 18, time elapsed 23.76s
+epoch: 0, iter: 250, avg_loss: 65.7556, kl/H(z|x): 12.0306, mi: 0.8201, recon: 53.7249,au 17, time elapsed 29.69s
+epoch: 0, iter: 300, avg_loss: 64.6366, kl/H(z|x): 12.2975, mi: 0.6563, recon: 52.3392,au 21, time elapsed 34.73s
+epoch: 0, iter: 350, avg_loss: 63.6333, kl/H(z|x): 12.1879, mi: 0.8044, recon: 51.4454,au 21, time elapsed 39.77s
+epoch: 0, iter: 400, avg_loss: 63.1732, kl/H(z|x): 12.0782, mi: 0.5898, recon: 51.0949,au 15, time elapsed 44.60s
+epoch: 0, iter: 450, avg_loss: 62.4059, kl/H(z|x): 11.7999, mi: 0.6608, recon: 50.6061,au 15, time elapsed 50.74s
+kl weight 1.0000
+VAL --- avg_loss: 51.0809, kl/H(z|x): 2.9957, mi: 0.6483, recon: 48.0950, nll: 51.0907, ppl: 104.0226
+22 active units
+update best loss
+TEST --- avg_loss: 51.0537, kl/H(z|x): 2.9736, mi: 0.6390, recon: 48.0735, nll: 51.0471, ppl: 103.6111
+epoch: 1, iter: 500, avg_loss: 59.4675, kl/H(z|x): 11.8408, recon: 47.6266,time elapsed 61.07s
+epoch: 1, iter: 550, avg_loss: 60.6828, kl/H(z|x): 11.7920, recon: 48.8908,time elapsed 62.34s
+epoch: 1, iter: 600, avg_loss: 60.5349, kl/H(z|x): 11.8571, recon: 48.6779,time elapsed 63.65s
+epoch: 1, iter: 650, avg_loss: 59.8986, kl/H(z|x): 11.9386, recon: 47.9600,time elapsed 64.92s
+epoch: 1, iter: 700, avg_loss: 59.4865, kl/H(z|x): 11.8225, recon: 47.6640,time elapsed 66.25s
+epoch: 1, iter: 750, avg_loss: 59.0838, kl/H(z|x): 11.7320, recon: 47.3517,time elapsed 67.57s
+epoch: 1, iter: 800, avg_loss: 58.9415, kl/H(z|x): 11.8043, recon: 47.1373,time elapsed 68.85s
+epoch: 1, iter: 850, avg_loss: 58.6082, kl/H(z|x): 11.6145, recon: 46.9938,time elapsed 70.12s
diff --git a/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt
new file mode 100644
index 0000000..c7c69b7
Binary files /dev/null and b/models/synthetic/synthetic_gamma0.80_fd2_fw2_dr1.00_nz32_drop0.30_0_0_783435_lr1.0_IAF/model.pt differ
diff --git a/modules/.DS_Store b/modules/.DS_Store
new file mode 100644
index 0000000..78526d3
Binary files /dev/null and b/modules/.DS_Store differ
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100755
index 0000000..62f9459
--- /dev/null
+++ b/modules/__init__.py
@@ -0,0 +1,6 @@
+from .encoders import *
+from .decoders import *
+from .vae import *
+from .vae_IAF import *
+from .utils import *
+from .discriminators import *
\ No newline at end of file
diff --git a/modules/decoders/__init__.py b/modules/decoders/__init__.py
new file mode 100755
index 0000000..eb2c937
--- /dev/null
+++ b/modules/decoders/__init__.py
@@ -0,0 +1,3 @@
+from .dec_lstm import *
+from .dec_pixelcnn import *
+from .dec_pixelcnn_v2 import *
\ No newline at end of file
diff --git a/modules/decoders/dec_lstm.py b/modules/decoders/dec_lstm.py
new file mode 100755
index 0000000..c9fc172
--- /dev/null
+++ b/modules/decoders/dec_lstm.py
@@ -0,0 +1,477 @@
+# import torch
+
+import time
+import argparse
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
+
+import numpy as np
+
+from .decoder import DecoderBase
+from .decoder_helper import BeamSearchNode
+
class LSTMDecoder(DecoderBase):
    """LSTM decoder with constant-length batching.

    The latent code ``z`` is used in two ways: it initializes the LSTM
    cell state via ``trans_linear`` (hidden state is ``tanh`` of that),
    and it is concatenated to every input word embedding at each step.

    NOTE(review): several vocabulary lookups below use empty-string keys
    (``self.vocab[""]`` / ``vocab['']``). These look like markup-stripped
    special tokens (likely ``<s>``, ``</s>``, ``<pad>``) — confirm against
    the original repository before relying on them.
    """
    def __init__(self, args, vocab, model_init, emb_init):
        """
        Args:
            args: namespace providing ni (embedding dim), dec_nh (LSTM
                hidden dim), nz (latent dim), dec_dropout_in,
                dec_dropout_out and device.
            vocab: word<->id mapping; must support len(), item lookup and
                id2word().
            model_init: callable applied to every parameter tensor.
            emb_init: callable applied to the embedding weight (after
                model_init, so it wins for the embedding matrix).
        """
        super(LSTMDecoder, self).__init__()
        self.ni = args.ni
        self.nh = args.dec_nh
        self.nz = args.nz
        self.vocab = vocab
        self.device = args.device

        # no padding when setting padding_idx to -1
        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1)

        self.dropout_in = nn.Dropout(args.dec_dropout_in)
        self.dropout_out = nn.Dropout(args.dec_dropout_out)

        # for initializing hidden state and cell
        self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False)

        # concatenate z with input
        self.lstm = nn.LSTM(input_size=args.ni + args.nz,
                            hidden_size=args.dec_nh,
                            num_layers=1,
                            batch_first=True)

        # prediction layer
        self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False)

        vocab_mask = torch.ones(len(vocab))
        # vocab_mask[vocab['']] = 0
        # NOTE(review): `reduce=False` is deprecated in modern PyTorch in
        # favor of `reduction='none'`; behavior (per-element losses) is
        # the same. Left unchanged here.
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False)

        self.reset_parameters(model_init, emb_init)

    def reset_parameters(self, model_init, emb_init):
        """Apply model_init to all parameters, then emb_init to embeddings."""
        for param in self.parameters():
            model_init(param)
        emb_init(self.embed.weight)

    def decode(self, input, z):
        """Run the LSTM over `input` conditioned on `z` and return logits.

        Args:
            input: (batch_size, seq_len) token ids.
            z: (batch_size, n_sample, nz) latent samples.

        Returns:
            output_logits: (batch_size * n_sample, seq_len, vocab_size)
        """

        # not predicting start symbol
        # sents_len -= 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            # broadcast the single z over every time step
            z_ = z.expand(batch_size, seq_len, self.nz)

        else:
            # replicate the word embeddings once per latent sample
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                .contiguous()

            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)

            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        z = z.view(batch_size * n_sample, self.nz)
        # z initializes the cell state; hidden state is tanh of it
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        # h_init = self.trans_linear(z).unsqueeze(0)
        # c_init = h_init.new_zeros(h_init.size())
        output, _ = self.lstm(word_embed, (h_init, c_init))

        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)

        return output_logits

    def reconstruct_error(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss
                across different sentence and z
        """

        #remove end symbol
        src = x[:, :-1]

        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode(src, z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)),
                tgt)


        # sum per-token losses over the sequence -> (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)


    def log_probability(self, x, z):
        """Log-likelihood of x under the decoder, given z.

        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            log_p: (batch_size, n_sample).
                log_p(x|z) across different x and z
        """

        return -self.reconstruct_error(x, z)

    def beam_search_decode(self, z, K=5):
        """beam search decoding, code is based on
        https://github.com/pcyin/pytorch_basic_nmt/blob/master/nmt.py

        the current implementation decodes sentence one by one, further batching would improve the speed

        Args:
            z: (batch_size, nz)
            K: the beam width

        Returns: List1
            List1: the decoded word sentence list
        """

        decoded_batch = []
        batch_size, nz = z.size()

        # (1, batch_size, nz)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)

        # decoding goes sentence by sentence
        for idx in range(batch_size):
            # Start with the start of the sentence token
            # NOTE(review): empty-string key — presumably the '<s>' token; verify.
            decoder_input = torch.tensor([[self.vocab[""]]], dtype=torch.long, device=self.device)
            decoder_hidden = (h_init[:,idx,:].unsqueeze(1), c_init[:,idx,:].unsqueeze(1))

            node = BeamSearchNode(decoder_hidden, None, decoder_input, 0., 1)
            live_hypotheses = [node]

            completed_hypotheses = []

            # hard cap of 100 steps guards against non-terminating beams
            t = 0
            while len(completed_hypotheses) < K and t < 100:
                t += 1

                # batch all live hypotheses into one LSTM step: (len(live), 1)
                decoder_input = torch.cat([node.wordid for node in live_hypotheses], dim=0)

                # (1, len(live), nh)
                decoder_hidden_h = torch.cat([node.h[0] for node in live_hypotheses], dim=1)
                decoder_hidden_c = torch.cat([node.h[1] for node in live_hypotheses], dim=1)

                decoder_hidden = (decoder_hidden_h, decoder_hidden_c)


                # (len(live), 1, ni) --> (len(live), 1, ni+nz)
                word_embed = self.embed(decoder_input)
                word_embed = torch.cat((word_embed, z[idx].view(1, 1, -1).expand(
                    len(live_hypotheses), 1, nz)), dim=-1)

                output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

                # (len(live), 1, vocab_size)
                output_logits = self.pred_linear(output)
                decoder_output = F.log_softmax(output_logits, dim=-1)

                # add each hypothesis' accumulated log-prob to its expansions
                prev_logp = torch.tensor([node.logp for node in live_hypotheses], dtype=torch.float, device=self.device)
                decoder_output = decoder_output + prev_logp.view(len(live_hypotheses), 1, 1)

                # flatten to rank all (hypothesis, word) pairs jointly
                # (len(live) * vocab_size)
                decoder_output = decoder_output.view(-1)

                # keep only as many candidates as beam slots remain open (K)
                log_prob, indexes = torch.topk(decoder_output, K-len(completed_hypotheses))

                # recover which hypothesis / which word each flat index encodes
                # NOTE(review): `//` floor-division on tensors is deprecated in
                # newer PyTorch (use torch.div(..., rounding_mode='floor')).
                live_ids = indexes // len(self.vocab)
                word_ids = indexes % len(self.vocab)

                live_hypotheses_new = []
                for live_id, word_id, log_prob_ in zip(live_ids, word_ids, log_prob):
                    node = BeamSearchNode((decoder_hidden[0][:, live_id, :].unsqueeze(1),
                        decoder_hidden[1][:, live_id, :].unsqueeze(1)),
                        live_hypotheses[live_id], word_id.view(1, 1), log_prob_, t)

                    # NOTE(review): empty-string key — presumably '</s>'; verify.
                    if word_id.item() == self.vocab[""]:
                        completed_hypotheses.append(node)
                    else:
                        live_hypotheses_new.append(node)

                live_hypotheses = live_hypotheses_new

                if len(completed_hypotheses) == K:
                    break

            # fill remaining beam slots with unfinished hypotheses
            for live in live_hypotheses:
                completed_hypotheses.append(live)

            utterances = []
            for n in sorted(completed_hypotheses, key=lambda node: node.logp, reverse=True):
                utterance = []
                utterance.append(self.vocab.id2word(n.wordid.item()))
                # back trace
                while n.prevNode != None:
                    n = n.prevNode
                    utterance.append(self.vocab.id2word(n.wordid.item()))

                # nodes were collected end-to-start; reverse into reading order
                utterance = utterance[::-1]

                utterances.append(utterance)

                # only save the top 1
                break

            decoded_batch.append(utterances[0])

        return decoded_batch

    def greedy_decode(self, z):
        """greedy decoding from z
        Args:
            z: (batch_size, nz)

        Returns: List1
            List1: the decoded word sentence list
        """

        batch_size = z.size(0)
        decoded_batch = [[] for _ in range(batch_size)]

        # (batch_size, 1, nz)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)

        decoder_hidden = (h_init, c_init)
        # NOTE(review): empty-string keys — presumably '<s>' start and '</s>'
        # end tokens stripped by markup; verify against the original repo.
        decoder_input = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device).unsqueeze(1)
        end_symbol = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device)

        # mask[i] == 1 while sentence i has not yet emitted the end symbol
        # (uint8 masks are deprecated in newer PyTorch in favor of bool)
        mask = torch.ones((batch_size), dtype=torch.uint8, device=self.device)
        length_c = 1
        while mask.sum().item() != 0 and length_c < 100:

            # (batch_size, 1, ni) --> (batch_size, 1, ni+nz)
            word_embed = self.embed(decoder_input)
            word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1)

            output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

            # (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
            decoder_output = self.pred_linear(output)
            output_logits = decoder_output.squeeze(1)

            # greedy: pick the argmax token at each step, (batch_size)
            max_index = torch.argmax(output_logits, dim=1)
            # max_index = torch.multinomial(probs, num_samples=1)

            decoder_input = max_index.unsqueeze(1)
            length_c += 1

            for i in range(batch_size):
                if mask[i].item():
                    decoded_batch[i].append(self.vocab.id2word(max_index[i].item()))

            # a sentence stays live only until it emits the end symbol
            mask = torch.mul((max_index != end_symbol), mask)

        return decoded_batch

    def sample_decode(self, z):
        """sampling decoding from z
        Args:
            z: (batch_size, nz)

        Returns: List1
            List1: the decoded word sentence list
        """

        batch_size = z.size(0)
        decoded_batch = [[] for _ in range(batch_size)]

        # (batch_size, 1, nz)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)

        decoder_hidden = (h_init, c_init)
        # NOTE(review): empty-string keys — presumably '<s>'/'</s>'; verify.
        decoder_input = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device).unsqueeze(1)
        end_symbol = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device)

        mask = torch.ones((batch_size), dtype=torch.uint8, device=self.device)
        length_c = 1
        while mask.sum().item() != 0 and length_c < 100:

            # (batch_size, 1, ni) --> (batch_size, 1, ni+nz)
            word_embed = self.embed(decoder_input)
            word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1)

            output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

            # (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
            decoder_output = self.pred_linear(output)
            output_logits = decoder_output.squeeze(1)

            # sample the next token from the softmax distribution, (batch_size)
            sample_prob = F.softmax(output_logits, dim=1)
            sample_index = torch.multinomial(sample_prob, num_samples=1).squeeze(1)

            decoder_input = sample_index.unsqueeze(1)
            length_c += 1

            for i in range(batch_size):
                if mask[i].item():
                    decoded_batch[i].append(self.vocab.id2word(sample_index[i].item()))

            mask = torch.mul((sample_index != end_symbol), mask)

        return decoded_batch
+
+
class VarLSTMDecoder(LSTMDecoder):
    """LSTM decoder with variable-length batching.

    Same model as LSTMDecoder, but sequences in a batch may have
    different lengths: inputs are packed with pack_padded_sequence and
    padding tokens are excluded from the loss via the class weight mask.

    NOTE(review): the ``vocab['']`` lookups below use an empty-string key
    that looks like a markup-stripped special token (likely '<pad>') —
    confirm against the original repository.
    """
    def __init__(self, args, vocab, model_init, emb_init):
        """
        Args:
            args: same namespace as LSTMDecoder.
            vocab: word<->id mapping; must contain the padding token.
            model_init: callable applied to every parameter tensor.
            emb_init: callable applied to the embedding weight.
        """
        super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)

        # re-create embedding/loss so padding is actually ignored
        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab[''])
        vocab_mask = torch.ones(len(vocab))
        # zero weight for the padding token so it contributes no loss
        vocab_mask[vocab['']] = 0
        # NOTE(review): `reduce=False` is deprecated (use reduction='none').
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False)

        self.reset_parameters(model_init, emb_init)

    def decode(self, input, z):
        """Run the LSTM over packed variable-length inputs conditioned on z.

        Args:
            input: tuple which contains x and sents_len
                x: (batch_size, seq_len)
                sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)

        Returns:
            output_logits: (batch_size * n_sample, seq_len, vocab_size)
        """

        input, sents_len = input

        # not predicting start symbol
        sents_len = sents_len - 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)

        else:
            # replicate embeddings once per latent sample
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                .contiguous()

            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)

            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        # lengths replicated to match the flattened (batch * n_sample) layout
        sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
        packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)

        z = z.view(batch_size * n_sample, self.nz)
        # h_init = self.trans_linear(z).unsqueeze(0)
        # c_init = h_init.new_zeros(h_init.size())
        # z initializes the cell state; hidden state is tanh of it
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        output, _ = self.lstm(packed_embed, (h_init, c_init))
        output, _ = pad_packed_sequence(output, batch_first=True)

        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)

        return output_logits

    def reconstruct_error(self, x, z):
        """Per-sentence cross entropy (padding excluded via loss weights).

        Args:
            x: tuple which contains x_ and sents_len
                x_: (batch_size, seq_len)
                sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss
                across different sentence and z
        """

        x, sents_len = x

        #remove end symbol
        src = x[:, :-1]

        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode((src, sents_len), z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)),
                tgt)


        # sum per-token losses over the sequence -> (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)
+
diff --git a/modules/decoders/dec_pixelcnn.py b/modules/decoders/dec_pixelcnn.py
new file mode 100755
index 0000000..d69ac66
--- /dev/null
+++ b/modules/decoders/dec_pixelcnn.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from .decoder import DecoderBase
+
def he_init(m):
    """He (Kaiming) normal initialization for a linear layer.

    Draws weights from N(0, sqrt(2 / fan_in)), in place.
    """
    scale = (2.0 / m.in_features) ** 0.5
    m.weight.data.normal_(0, scale)
+
class GatedMaskedConv2d(nn.Module):
    """Gated masked convolution with vertical and horizontal stacks
    (PixelCNN-style; the gate is tanh(a) * sigmoid(b)).

    mask 'A' shifts the horizontal stack so the current pixel is not seen
    (first layer); mask 'B' allows it and adds a residual connection.
    """
    def __init__(self, in_dim, out_dim=None, kernel_size = 3, mask = 'B'):
        """
        Args:
            in_dim: number of input channels.
            out_dim: number of output channels (defaults to in_dim).
            kernel_size: odd spatial kernel size.
            mask: 'A' (exclude current pixel) or 'B' (include it).
        """
        super(GatedMaskedConv2d, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.dim = out_dim
        self.size = kernel_size
        self.mask = mask
        pad = self.size // 2

        # vertical stack: sees rows strictly above the current pixel
        self.v_conv = nn.Conv2d(in_dim, 2*self.dim, kernel_size=(pad+1, self.size))
        self.v_pad1 = nn.ConstantPad2d((pad, pad, pad, 0), 0)
        self.v_pad2 = nn.ConstantPad2d((0, 0, 1, 0), 0)
        self.vh_conv = nn.Conv2d(2*self.dim, 2*self.dim, kernel_size = 1)

        # horizontal stack: sees pixels to the left in the current row
        self.h_conv = nn.Conv2d(in_dim, 2*self.dim, kernel_size=(1, pad+1))
        self.h_pad1 = nn.ConstantPad2d((self.size // 2, 0, 0, 0), 0)
        self.h_pad2 = nn.ConstantPad2d((1, 0, 0, 0), 0)
        self.h_conv_res = nn.Conv2d(self.dim, self.dim, 1)

    def forward(self, v_map, h_map):
        """
        Args:
            v_map: vertical-stack feature map (batch, in_dim, H, W).
            h_map: horizontal-stack feature map (batch, in_dim, H, W).

        Returns:
            (v_map_out, h_map_out), each (batch, out_dim, H, W).
        """
        # pad-then-crop shifts the vertical receptive field one row up
        v_out = self.v_pad2(self.v_conv(self.v_pad1(v_map)))[:, :, :-1, :]
        # Fix: torch.tanh / torch.sigmoid replace the deprecated
        # F.tanh / F.sigmoid (removed warnings, identical values).
        v_map_out = torch.tanh(v_out[:, :self.dim]) * torch.sigmoid(v_out[:, self.dim:])
        # feed the vertical stack into the horizontal stack
        vh = self.vh_conv(v_out)

        h_out = self.h_conv(self.h_pad1(h_map))
        if self.mask == 'A':
            # shift right by one so the current pixel is excluded
            h_out = self.h_pad2(h_out)[:, :, :, :-1]
        h_out = h_out + vh
        h_out = torch.tanh(h_out[:, :self.dim]) * torch.sigmoid(h_out[:, self.dim:])
        h_map_out = self.h_conv_res(h_out)
        if self.mask == 'B':
            # residual connection (only valid when in_dim == out_dim)
            h_map_out = h_map_out + h_map
        return v_map_out, h_map_out
+
class StackedGatedMaskedConv2d(nn.Module):
    """Stack of gated masked convolutions, optionally conditioned on a
    latent code that is projected to extra input feature maps.
    """
    def __init__(self,
                 img_size = [1, 28, 28], layers = [64,64,64],
                 kernel_size = [7,7,7], latent_dim=64, latent_feature_map = 1):
        """
        Args:
            img_size: [channels, H, W] of the input image. NOTE: z_linear
                hard-codes 28x28, so H and W must be 28 when conditioning.
            layers: output channels of each gated conv layer.
            kernel_size: kernel size per layer (same length as `layers`).
            latent_dim: dimensionality of the latent code.
            latent_feature_map: number of extra input planes the latent
                code is projected to (0 disables conditioning).

        (The list defaults are never mutated, so sharing them is safe.)
        """
        super(StackedGatedMaskedConv2d, self).__init__()
        input_dim = img_size[0]
        self.conv_layers = []
        if latent_feature_map > 0:
            self.latent_feature_map = latent_feature_map
            # project z to latent_feature_map extra 28x28 input planes
            self.z_linear = nn.Linear(latent_dim, latent_feature_map*28*28)
        for i in range(len(kernel_size)):
            if i == 0:
                # first layer uses mask 'A' (current pixel excluded)
                self.conv_layers.append(GatedMaskedConv2d(input_dim+latent_feature_map,
                                                          layers[i], kernel_size[i], 'A'))
            else:
                self.conv_layers.append(GatedMaskedConv2d(layers[i-1], layers[i], kernel_size[i]))

        # NOTE(review): this shadows nn.Module.modules() on the instance;
        # kept as-is because saved checkpoints key parameters under
        # "modules.*" — renaming would break state_dict loading.
        self.modules = nn.ModuleList(self.conv_layers)

    def forward(self, img, q_z=None):
        """
        Args:
            img: (batch, nc, H, W)
            q_z: (batch, nsamples, nz), or None for unconditional decoding.

        Returns:
            horizontal-stack feature map of the last layer,
            (batch * nsamples, layers[-1], H, W).
        """
        if q_z is not None:
            # Fix: only touch q_z after the None check. The original
            # unpacked q_z.size() unconditionally, so q_z=None (which the
            # signature advertises) always crashed with AttributeError.
            batch_size, nsamples, _ = q_z.size()
            z_img = self.z_linear(q_z)
            z_img = z_img.view(img.size(0), nsamples, self.latent_feature_map,
                               img.size(2), img.size(3))

            # replicate the image per latent sample: (batch, nsamples, nc, H, W)
            img = img.unsqueeze(1).expand(batch_size, nsamples, *img.size()[1:])

        for i in range(len(self.conv_layers)):
            if i == 0:
                if q_z is not None:
                    # (batch, nsamples, nc + fm, H, W) --> (batch * nsamples, nc + fm, H, W)
                    v_map = torch.cat([img, z_img], 2)
                    v_map = v_map.view(-1, *v_map.size()[2:])
                else:
                    v_map = img
                h_map = v_map
            v_map, h_map = self.conv_layers[i](v_map, h_map)
        return h_map
+
class PixelCNNDecoder(DecoderBase):
    """PixelCNN image decoder producing per-pixel Bernoulli parameters
    conditioned on a latent code."""
    def __init__(self, args):
        """
        Args:
            args: namespace with img_size, dec_layers, nz,
                dec_kernel_size and latent_feature_map.
        """
        super(PixelCNNDecoder, self).__init__()
        self.dec_cnn = StackedGatedMaskedConv2d(img_size=args.img_size, layers = args.dec_layers,
                                                latent_dim= args.nz, kernel_size = args.dec_kernel_size,
                                                latent_feature_map = args.latent_feature_map)

        # 1x1 conv mapping the last feature map to per-pixel logits
        self.dec_linear = nn.Conv2d(args.dec_layers[-1], args.img_size[0], kernel_size = 1)

        self.reset_parameters()

    def reset_parameters(self):
        # he_init only targets Linear layers (i.e. the latent projection);
        # conv layers keep their own initialization.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                he_init(m)

    def decode(self, img, q_z):
        """Return per-pixel Bernoulli probabilities for img given q_z."""
        dec_cnn_output = self.dec_cnn(img, q_z)
        # Fix: torch.sigmoid replaces the deprecated F.sigmoid
        # (identical values, no deprecation warning).
        pred = torch.sigmoid(self.dec_linear(dec_cnn_output))
        return pred

    def reconstruct_error(self, x, z):
        """Negative Bernoulli log-likelihood of the image x given z.

        (Docstring fixed: the original said "language case", copied from
        the LSTM decoder.)

        Args:
            x: (batch_size, nc, H, W)
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss
                across different images and z
        """

        batch_size, nsamples, _ = z.size()
        # (batch * nsamples, nc, H, W)
        pred = self.decode(x, z)
        # clamp away 0/1 so the logs below stay finite
        prob = torch.clamp(pred.view(pred.size(0), -1), min=1e-5, max=1.-1e-5)

        # (batch, nsamples, nc, H, W) --> (batch * nsamples, nc, H, W)
        x = x.unsqueeze(1).expand(batch_size, nsamples, *x.size()[1:]).contiguous()
        tgt_vec = x.view(-1, *x.size()[2:])

        # (batch * nsamples, *)
        tgt_vec = tgt_vec.view(tgt_vec.size(0), -1)

        # elementwise Bernoulli log-likelihood
        log_bernoulli = tgt_vec * torch.log(prob) + (1. - tgt_vec)*torch.log(1. - prob)

        log_bernoulli = log_bernoulli.view(batch_size, nsamples, -1)

        return -torch.sum(log_bernoulli, 2)


    def log_probability(self, x, z):
        """Bernoulli log-likelihood log p(x|z).

        Args:
            x: (batch_size, nc, H, W)
            z: (batch_size, n_sample, nz)
        Returns:
            log_p: (batch_size, n_sample).
                log_p(x|z) across different x and z
        """

        return -self.reconstruct_error(x, z)
diff --git a/modules/decoders/dec_pixelcnn_v2.py b/modules/decoders/dec_pixelcnn_v2.py
new file mode 100755
index 0000000..6987511
--- /dev/null
+++ b/modules/decoders/dec_pixelcnn_v2.py
@@ -0,0 +1,232 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+import numpy as np
+
+from .decoder import DecoderBase
+
+class MaskedConv2d(nn.Conv2d):
+    """Conv2d with a fixed autoregressive (PixelCNN) mask applied to the kernel.
+
+    mask_type 'A' also hides the centre pixel (first layer so a pixel never
+    sees itself); 'B' allows the centre (subsequent layers). Only the first
+    `masked_channels` input channels are masked, so any extra conditioning
+    channels remain fully visible.
+    """
+    def __init__(self, mask_type, masked_channels, *args, **kwargs):
+        super(MaskedConv2d, self).__init__(*args, **kwargs)
+        assert mask_type in {'A', 'B'}
+        # Registered as a buffer so it moves with .to()/.cuda() but is not trained.
+        self.register_buffer('mask', self.weight.data.clone())
+        _, _, kH, kW = self.weight.size()
+        self.mask.fill_(1)
+        # Zero everything right of the centre column (inclusive for mask 'A')...
+        self.mask[:, :masked_channels, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
+        # ...and every row below the centre.
+        self.mask[:, :masked_channels, kH // 2 + 1:] = 0
+
+    def reset_parameters(self):
+        # He-style init scaled by fan-out; biases start at zero.
+        n = self.kernel_size[0] * self.kernel_size[1] * self.out_channels
+        self.weight.data.normal_(0, math.sqrt(2. / n))
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, x):
+        # Re-apply the mask on every call so optimizer updates cannot unmask weights.
+        self.weight.data.mul_(self.mask)
+        return super(MaskedConv2d, self).forward(x)
+
+class PixelCNNBlock(nn.Module):
+    """Residual bottleneck block: 1x1 reduce -> masked ('B') conv -> 1x1 expand, plus skip."""
+    def __init__(self, in_channels, kernel_size):
+        super(PixelCNNBlock, self).__init__()
+        self.mask_type = 'B'
+        padding = kernel_size // 2
+        # Bottleneck: halve the channel count inside the block.
+        out_channels = in_channels // 2
+
+        self.main = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ELU(),
+            MaskedConv2d(self.mask_type, out_channels, out_channels, out_channels, kernel_size, padding=padding, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ELU(),
+            nn.Conv2d(out_channels, in_channels, 1, bias=False),
+            nn.BatchNorm2d(in_channels),
+        )
+        self.activation = nn.ELU()
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # He init for convs; BatchNorm starts as identity (weight=1, bias=0).
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, input):
+        # Residual connection keeps the autoregressive property (mask 'B' inside).
+        return self.activation(self.main(input) + input)
+
+
+class MaskABlock(nn.Module):
+    """First PixelCNN layer: a type-'A' masked conv (a pixel never sees itself)."""
+    def __init__(self, in_channels, out_channels, kernel_size, masked_channels):
+        super(MaskABlock, self).__init__()
+        self.mask_type = 'A'
+        padding = kernel_size // 2
+
+        self.main = nn.Sequential(
+            MaskedConv2d(self.mask_type, masked_channels, in_channels, out_channels, kernel_size, padding=padding, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ELU(),
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # Only the BatchNorm needs explicit init here; the masked conv
+        # initializes itself in its own reset_parameters.
+        m = self.main[1]
+        assert isinstance(m, nn.BatchNorm2d)
+        m.weight.data.fill_(1)
+        m.bias.data.zero_()
+
+    def forward(self, input):
+        return self.main(input)
+
+
+class PixelCNN(nn.Module):
+    """Stack of masked-conv blocks with fixed-lag residual 'direct' connections."""
+    def __init__(self, in_channels, out_channels, num_blocks, kernel_sizes, masked_channels):
+        super(PixelCNN, self).__init__()
+        assert num_blocks == len(kernel_sizes)
+        self.blocks = []
+        for i in range(num_blocks):
+            if i == 0:
+                # First layer uses mask 'A' so a pixel never conditions on itself.
+                block = MaskABlock(in_channels, out_channels, kernel_sizes[i], masked_channels)
+            else:
+                block = PixelCNNBlock(out_channels, kernel_sizes[i])
+            self.blocks.append(block)
+
+        self.main = nn.ModuleList(self.blocks)
+
+        # One transform per skip connection (used from block index 3 onward,
+        # plus one final connection after the loop).
+        self.direct_connects = []
+        for i in range(1, num_blocks - 1):
+            self.direct_connects.append(PixelCNNBlock(out_channels, kernel_sizes[i]))
+
+        self.direct_connects = nn.ModuleList(self.direct_connects)
+
+    def forward(self, input):
+        # [batch, out_channels, H, W]
+        # Each block's output is queued; from block 3 on, the output from three
+        # blocks back is transformed and added to the current input (a fixed-lag
+        # skip connection). The final assert checks exactly 3 outputs remain.
+        direct_inputs = []
+        for i, layer in enumerate(self.main):
+            if i > 2:
+                direct_input = direct_inputs.pop(0)
+                direct_conncet = self.direct_connects[i - 3]
+                input = input + direct_conncet(direct_input)
+
+            input = layer(input)
+            direct_inputs.append(input)
+        assert len(direct_inputs) == 3, 'architecture error: %d' % len(direct_inputs)
+        direct_conncet = self.direct_connects[-1]
+        return input + direct_conncet(direct_inputs.pop(0))
+
+class PixelCNNDecoderV2(DecoderBase):
+    """PixelCNN decoder for 28x28 binary images, conditioned on latent z via
+    extra feature-map channels concatenated to the image input."""
+    def __init__(self, args, ngpu=1, mode='large'):
+        super(PixelCNNDecoderV2, self).__init__()
+        self.ngpu = ngpu
+        self.nz = args.nz
+        self.nc = 1
+        # self.fm_latent = 4 old ??? hardcoded
+        # Number of conditioning channels derived from z.
+        self.fm_latent = args.latent_feature_map
+
+        # z is linearly mapped to fm_latent feature maps of size 28x28.
+        self.img_latent = 28 * 28 * self.fm_latent
+        if self.nz != 0 :
+            self.z_transform = nn.Sequential(
+                nn.Linear(self.nz, self.img_latent),
+            )
+        # Kernel-size schedule selects network depth ('small' = 7 blocks, 'large' = 13).
+        if mode == 'small':
+            kernal_sizes = [7, 7, 7, 5, 5, 3, 3]
+        elif mode == 'large':
+            kernal_sizes = [7, 7, 7, 7, 7, 5, 5, 5, 5, 3, 3, 3, 3]
+        else:
+            raise ValueError('unknown mode: %s' % mode)
+
+        hidden_channels = 64
+        # Only the nc image channels are masked; the fm_latent conditioning
+        # channels stay fully visible to every pixel.
+        self.main = nn.Sequential(
+            PixelCNN(self.nc + self.fm_latent, hidden_channels, len(kernal_sizes), kernal_sizes, self.nc),
+            nn.Conv2d(hidden_channels, hidden_channels, 1, bias=False),
+            nn.BatchNorm2d(hidden_channels),
+            nn.ELU(),
+            nn.Conv2d(hidden_channels, self.nc, 1, bias=False),
+            nn.Sigmoid(),
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.nz != 0:
+            nn.init.xavier_uniform_(self.z_transform[0].weight)
+            nn.init.constant_(self.z_transform[0].bias, 0)
+
+        m = self.main[2]
+        assert isinstance(m, nn.BatchNorm2d)
+        m.weight.data.fill_(1)
+        m.bias.data.zero_()
+
+    def forward(self, input):
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output
+
+    def reconstruct_error(self, x, z):
+        """Negative Bernoulli log-likelihood of x; z=None means unconditional.
+        Returns shape [batch, nsamples]."""
+        eps = 1e-12
+        # NOTE(review): `type(z) == type(None)` is unidiomatic — `z is None` is standard.
+        if type(z) == type(None):
+            batch_size, nsampels, _, _ = x.size()
+            img = x.unsqueeze(1).expand(batch_size, nsampels, *x.size()[1:])
+        else:
+            batch_size, nsampels, nz = z.size()
+            # [batch, nsamples, -1] --> [batch, nsamples, fm, H, W]
+            z = self.z_transform(z).view(batch_size, nsampels, self.fm_latent, 28, 28)
+
+            # [batch, nc, H, W] --> [batch, 1, nc, H, W] --> [batch, nsample, nc, H, W]
+            img = x.unsqueeze(1).expand(batch_size, nsampels, *x.size()[1:])
+            # [batch, nsample, nc+fm, H, W] --> [batch * nsamples, nc+fm, H, W]
+            img = torch.cat([img, z], dim=2)
+
+        img = img.view(-1, *img.size()[2:])
+
+        # [batch * nsamples, *] --> [batch, nsamples, -1]
+        recon_x = self.forward(img).view(batch_size, nsampels, -1)
+        # [batch, -1]
+        x_flat = x.view(batch_size, -1)
+        # eps keeps the logs finite when recon_x saturates at 0 or 1.
+        BCE = (recon_x + eps).log() * x_flat.unsqueeze(1) + (1.0 - recon_x + eps).log() * (1. - x_flat).unsqueeze(1) #cross-entropy
+        # [batch, nsamples]
+        return BCE.sum(dim=2) * -1.0
+
+    def log_probability(self, x, z):
+        """log p(x|z): the negated reconstruction error, shape [batch, nsamples]."""
+        bce = self.reconstruct_error(x, z)
+        return bce * -1.
+
+    def decode(self, z, deterministic):
+        '''Autoregressively generate images pixel by pixel.
+
+        Args:
+            z: Tensor
+                the tensor of latent z shape=[batch, nz]
+            deterministic: boolean
+                randomly sample or decode via argmaximizing probability
+
+        Returns: Tensor
+            the tensor of decoded x shape=[batch, *]
+
+        '''
+        H = W = 28
+        batch_size, nz = z.size()
+
+        # [batch, -1] --> [batch, fm, H, W]
+        z = self.z_transform(z).view(batch_size, self.fm_latent, H, W)
+        # NOTE(review): Variable(..., volatile=True) is deprecated; modern
+        # PyTorch uses torch.no_grad() instead.
+        img = Variable(z.data.new(batch_size, self.nc, H, W).zero_(), volatile=True)
+        # [batch, nc+fm, H, W]
+        img = torch.cat([img, z], dim=1)
+        # O(H*W) full forward passes: each generated pixel is written back into
+        # img before predicting the next one (raster-scan order).
+        for i in range(H):
+            for j in range(W):
+                # [batch, nc, H, W]
+                recon_img = self.forward(img)
+                # [batch, nc]
+                img[:, :self.nc, i, j] = torch.ge(recon_img[:, :, i, j], 0.5).float() if deterministic else torch.bernoulli(recon_img[:, :, i, j])
+                # img[:, :self.nc, i, j] = torch.bernoulli(recon_img[:, :, i, j])
+
+        # [batch, nc, H, W]
+        img_probs = self.forward(img)
+        return img[:, :self.nc], img_probs
diff --git a/modules/decoders/decoder.py b/modules/decoders/decoder.py
new file mode 100755
index 0000000..a419661
--- /dev/null
+++ b/modules/decoders/decoder.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+
+class DecoderBase(nn.Module):
+    """Abstract base class for decoders p(x|z); subclasses implement the hooks below."""
+    def __init__(self):
+        super(DecoderBase, self).__init__()
+
+    def decode(self, x, z):
+        # Produce a reconstruction/prediction of x conditioned on z.
+        raise NotImplementedError
+
+    def reconstruct_error(self, x, z):
+        """reconstruction loss
+        Args:
+            x: (batch_size, *)
+            z: (batch_size, n_sample, nz)
+        Returns:
+            loss: (batch_size, n_sample). Loss
+                across different sentence and z
+        """
+
+        raise NotImplementedError
+
+    def beam_search_decode(self, z, K):
+        """beam search decoding
+        Args:
+            z: (batch_size, nz)
+            K: the beam size
+
+        Returns: List1
+            List1: the decoded word sentence list
+        """
+
+        raise NotImplementedError
+
+    def sample_decode(self, z):
+        """sampling from z
+        Args:
+            z: (batch_size, nz)
+
+        Returns: List1
+            List1: the decoded word sentence list
+        """
+
+        raise NotImplementedError
+
+    def greedy_decode(self, z):
+        """greedy decoding from z
+        Args:
+            z: (batch_size, nz)
+
+        Returns: List1
+            List1: the decoded word sentence list
+        """
+
+        raise NotImplementedError
+
+    def log_probability(self, x, z):
+        """
+        Args:
+            x: (batch_size, *)
+            z: (batch_size, n_sample, nz)
+        Returns:
+            log_p: (batch_size, n_sample).
+                log_p(x|z) across different x and z
+        """
+
+        raise NotImplementedError
+
+
+
+
\ No newline at end of file
diff --git a/modules/decoders/decoder_helper.py b/modules/decoders/decoder_helper.py
new file mode 100755
index 0000000..6b0bf31
--- /dev/null
+++ b/modules/decoders/decoder_helper.py
@@ -0,0 +1,23 @@
+
+
+class BeamSearchNode(object):
+    """A node in the beam-search tree: wraps decoder state and the running score."""
+    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
+        '''
+        :param hiddenstate: decoder hidden state at this step
+        :param previousNode: parent BeamSearchNode (None for the root)
+        :param wordId: token id emitted at this step
+        :param logProb: cumulative log-probability of the partial hypothesis
+        :param length: number of tokens generated so far
+        '''
+        self.h = hiddenstate
+        self.prevNode = previousNode
+        self.wordid = wordId
+        self.logp = logProb
+        self.leng = length
+
+    def eval(self, alpha=1.0):
+        # Length-normalized score; `reward` is a hook for score shaping (currently 0).
+        reward = 0
+        # Add here a function for shaping a reward
+
+        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward
+
diff --git a/modules/discriminators/__init__.py b/modules/discriminators/__init__.py
new file mode 100755
index 0000000..fe33551
--- /dev/null
+++ b/modules/discriminators/__init__.py
@@ -0,0 +1 @@
+from .discriminator_linear import *
\ No newline at end of file
diff --git a/modules/discriminators/discriminator_linear.py b/modules/discriminators/discriminator_linear.py
new file mode 100755
index 0000000..139b73a
--- /dev/null
+++ b/modules/discriminators/discriminator_linear.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
+
+
+class LinearDiscriminator_only(nn.Module):
+    """Linear probe classifier over latent features (e.g. posterior means)."""
+
+    def __init__(self, args, ncluster):
+        super(LinearDiscriminator_only, self).__init__()
+        self.args = args
+        # NOTE(review): both branches are identical, so the IAF flag currently
+        # has no effect here — confirm whether a different layer was intended.
+        if args.IAF:
+            self.linear = nn.Linear(args.nz, ncluster)
+        else:
+            self.linear = nn.Linear(args.nz, ncluster)
+        # Per-example losses (reduction="none") so callers can aggregate as needed.
+        self.loss = nn.CrossEntropyLoss(reduction="none")
+
+
+    def get_performance_with_feature(self, batch_data, batch_labels):
+        """Return (per-example CE loss, number of correct predictions)."""
+        mu = batch_data
+        logits = self.linear(mu)
+        loss = self.loss(logits, batch_labels)
+
+        _, pred = torch.max(logits, dim=1)
+        correct = torch.eq(pred, batch_labels).float().sum().item()
+
+        return loss, correct
diff --git a/modules/encoders/.DS_Store b/modules/encoders/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/modules/encoders/.DS_Store differ
diff --git a/modules/encoders/__init__.py b/modules/encoders/__init__.py
new file mode 100755
index 0000000..4c78073
--- /dev/null
+++ b/modules/encoders/__init__.py
@@ -0,0 +1,5 @@
+from .enc_lstm import *
+from .enc_resnet_v2 import *
+from .flow import *
+from .enc_flow import *
+
diff --git a/modules/encoders/enc_flow.py b/modules/encoders/enc_flow.py
new file mode 100644
index 0000000..21c3382
--- /dev/null
+++ b/modules/encoders/enc_flow.py
@@ -0,0 +1,416 @@
+from .flow import *
+import torch
+from torch import nn
+import numpy as np
+import math
+from ..utils import log_sum_exp
+
+
+class IAFEncoderBase(nn.Module):
+    """Base class for flow-based (IAF) encoders; subclasses implement forward()."""
+
+    def __init__(self):
+        super(IAFEncoderBase, self).__init__()
+
+    def sample(self, input, nsamples):
+        """sampling from the encoder
+        Returns: Tensor1, Tensor2
+            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
+            Tensor2: log q(z|x) for the sampled z
+        """
+
+        z_T, log_q_z = self.forward(input, nsamples)
+
+        return z_T, log_q_z
+
+    def forward(self, x, n_sample):
+        """
+        Args:
+            x: (batch_size, *)
+
+        Returns: Tensor1, Tensor2
+            Tensor1: the flowed latent z_T, shape (batch, nsamples, nz)
+            Tensor2: log q(z_T|x), shape (batch, nsamples, nz)
+        """
+
+        raise NotImplementedError
+
+    def encode(self, input, args):
+        """perform the encoding and compute the KL term
+
+        Returns: Tensor1, Tensor2
+            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
+            Tensor2: the tenor of KL for each x with shape [batch]
+
+        """
+
+        # (batch, nsamples, nz)
+        z_T, log_q_z = self.forward(input, args.nsamples)
+
+        # Monte-Carlo KL: log q(z|x) - log p(z) with p(z) = N(0, I).
+        log_p_z = self.log_q_z_0(z=z_T)  # [b s nz]
+
+        kl = log_q_z - log_p_z
+
+        # free-bit
+        # Clamp the per-dimension KL (averaged over batch and samples) to a
+        # floor of target_kl so the posterior cannot collapse dimension-wise.
+        if self.training and args.fb == 1 and args.target_kl > 0:
+            kl_obj = torch.mean(kl, dim=[0, 1], keepdim=True)
+            kl_obj = torch.clamp_min(kl_obj, args.target_kl)
+            kl_obj = kl_obj.expand(kl.size(0), kl.size(1), -1)
+            kl = kl_obj
+
+        return z_T, kl.sum(dim=[1, 2])  # like KL
+
+    def reparameterize(self, mu, logvar, nsamples=1):
+        """sample from posterior Gaussian family
+        Args:
+            mu: Tensor
+                Mean of gaussian distribution with shape (batch, nz)
+
+            logvar: Tensor
+                logvar of gaussian distibution with shape (batch, nz)
+
+        Returns: Tensor
+            Sampled z with shape (batch, nsamples, nz)
+        """
+        # import ipdb
+        # ipdb.set_trace()
+        batch_size, nz = mu.size()
+        std = logvar.mul(0.5).exp()
+
+        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
+        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
+
+        eps = torch.zeros_like(std_expd).normal_()
+
+        return mu_expd + torch.mul(eps, std_expd)
+
+    def eval_inference_dist(self, x, z, param=None):
+        """this function computes log q(z | x)
+        Args:
+            z: tensor
+                different z points that will be evaluated, with
+                shape [batch, nsamples, nz]
+            param: optional (mu, logvar) pair; recomputed from x when absent
+        Returns: Tensor1
+            Tensor1: log q(z|x) with shape [batch, nsamples]
+        """
+
+        nz = z.size(2)
+
+        # NOTE(review): `if not param` would also trigger on an empty tuple;
+        # `param is None` would be the safer test.
+        if not param:
+            mu, logvar = self.forward(x)
+        else:
+            mu, logvar = param
+
+        # if self.args.gamma <0:
+        #     mu,logvar = self.trans_param(mu,logvar)
+
+        # import ipdb
+        # ipdb.set_trace()
+
+        # (batch_size, 1, nz)
+        mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
+        var = logvar.exp()
+
+        # (batch_size, nsamples, nz)
+        dev = z - mu
+
+        # (batch_size, nsamples)
+        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
+                      0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
+
+        return log_density
+
+
+class VariationalFlow(IAFEncoderBase):
+    """Approximate posterior parameterized by a flow (https://arxiv.org/abs/1606.04934).
+
+    LSTM text encoder producing (loc, scale, context h); z_0 is reparameterized
+    then pushed through a stack of IAF transforms. Optionally applies
+    batch-norm to the mean (BN, gamma > 0) and variance dropout (DP).
+    """
+
+    def __init__(self, args, vocab_size, model_init, emb_init):
+        super().__init__()
+
+        self.ni = args.ni
+        self.nh = args.enc_nh
+        self.nz = args.nz
+        self.args = args
+
+        flow_depth = args.flow_depth
+        flow_width = args.flow_width
+
+        self.embed = nn.Embedding(vocab_size, args.ni)
+        self.lstm = nn.LSTM(input_size=args.ni, hidden_size=args.enc_nh, num_layers=1,
+                            batch_first=True, dropout=0)
+
+        # 4*nz outputs: (loc, scale_arg) plus a 2*nz flow context vector.
+        self.linear = nn.Linear(args.enc_nh, 4 * args.nz, bias=False)
+        modules = []
+        for _ in range(flow_depth):
+            modules.append(InverseAutoregressiveFlow(num_input=args.nz,
+                                                     num_hidden=flow_width * args.nz,  # hidden dim in MADE
+                                                     num_context=2 * args.nz))
+            # Reverse dimension order between IAF steps so all dims get transformed.
+            modules.append(Reverse(args.nz))
+
+        self.q_z_flow = FlowSequential(*modules)
+        self.log_q_z_0 = NormalLogProb()
+        self.softplus = nn.Softplus()
+        self.reset_parameters(model_init, emb_init)
+
+        self.BN = False
+        if self.args.gamma > 0:
+            self.BN = True
+            self.mu_bn = nn.BatchNorm1d(args.nz, eps=1e-8)
+            self.gamma = args.gamma
+            nn.init.constant_(self.mu_bn.weight, self.args.gamma)
+            nn.init.constant_(self.mu_bn.bias, 0.0)
+
+        self.DP = False
+        if self.args.p_drop > 0 and self.args.delta_rate > 0:
+            self.DP = True
+            self.p_drop = self.args.p_drop
+            self.delta_rate = self.args.delta_rate
+
+    def reset_parameters(self, model_init, emb_init):
+        for name, param in self.lstm.named_parameters():
+            # self.initializer(param)
+            if 'bias' in name:
+                nn.init.constant_(param, 0.0)
+                # model_init(param)
+            elif 'weight' in name:
+                model_init(param)
+
+        model_init(self.linear.weight)
+        emb_init(self.embed.weight)
+
+    def forward(self, input, n_samples):
+        """Return sample of latent variable and log prob."""
+        word_embed = self.embed(input)
+        _, (last_state, last_cell) = self.lstm(word_embed)
+        loc_scale, h = self.linear(last_state.squeeze(0)).chunk(2, -1)
+        loc, scale_arg = loc_scale.chunk(2, -1)
+        scale = self.softplus(scale_arg)
+
+        if self.BN:
+            # Rescale the BN weight so its RMS equals gamma before normalizing loc.
+            ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5
+            #if ss < self.gamma:
+            self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss
+            loc = self.mu_bn(loc)
+        if self.DP and self.args.kl_weight >= self.args.drop_start:
+            # Variance dropout plus an additive floor of delta_rate/(2*e*pi).
+            var = scale ** 2
+            var = torch.dropout(var, p=self.p_drop, train=self.training)
+            var += self.delta_rate * 1.0 / (2 * math.e * math.pi)
+            scale = var ** 0.5
+
+        loc = loc.unsqueeze(1)
+        scale = scale.unsqueeze(1)
+        h = h.unsqueeze(1)
+
+        eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device)
+        z_0 = loc + scale * eps  # reparameterization
+        log_q_z_0 = self.log_q_z_0(loc=loc, scale=scale, z=z_0)
+        z_T, log_q_z_flow = self.q_z_flow(z_0, context=h)
+        log_q_z = (log_q_z_0 + log_q_z_flow)  # [b s nz]
+
+        # NOTE(review): leftover debug trap — drops into ipdb on NaN; remove for production.
+        if torch.sum(torch.isnan(z_T)):
+            import ipdb
+            ipdb.set_trace()
+
+        ################
+        # NOTE(review): this rarely-taken branch is a no-op (bare attribute
+        # access, presumably a leftover from a removed print/log statement).
+        if torch.rand(1).sum() <= 0.0005:
+            if self.BN:
+                self.mu_bn.weight
+
+        return z_T, log_q_z
+        # return z_0, log_q_z_0.sum(-1)
+
+    def infer_param(self, input):
+        # Return (loc, logvar) of the base Gaussian (before the flow).
+        word_embed = self.embed(input)
+        _, (last_state, last_cell) = self.lstm(word_embed)
+        loc_scale, h = self.linear(last_state.squeeze(0)).chunk(2, -1)
+        loc, scale_arg = loc_scale.chunk(2, -1)
+        scale = self.softplus(scale_arg)
+        # logvar = scale_arg
+
+        if self.BN:
+            ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5
+            if ss < self.gamma:
+                self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss
+            loc = self.mu_bn(loc)
+        if self.DP and self.args.kl_weight >= self.args.drop_start:
+            var = scale ** 2
+            var = torch.dropout(var, p=self.p_drop, train=self.training)
+            var += self.delta_rate * 1.0 / (2 * math.e * math.pi)
+            scale = var ** 0.5
+
+        return loc, torch.log(scale ** 2)
+
+    def learn_feature(self, input):
+        # Push the (deterministic) posterior mean through the flow as a feature.
+        word_embed = self.embed(input)
+        _, (last_state, last_cell) = self.lstm(word_embed)
+        loc_scale, h = self.linear(last_state.squeeze(0)).chunk(2, -1)
+        loc, scale_arg = loc_scale.chunk(2, -1)
+        # NOTE(review): unconditional ipdb breakpoint left in — this will halt
+        # any caller; almost certainly debug residue.
+        import ipdb
+        ipdb.set_trace()
+        if self.BN:
+            loc = self.mu_bn(loc)
+        loc = loc.unsqueeze(1)
+        h = h.unsqueeze(1)
+        z_T, log_q_z_flow = self.q_z_flow(loc, context=h)
+        return loc, z_T
+
+
+from .enc_resnet_v2 import ResNet
+
+
+class FlowResNetEncoderV2(IAFEncoderBase):
+    """ResNet image encoder followed by an IAF flow posterior.
+
+    Mirrors VariationalFlow but with a convolutional trunk instead of an LSTM;
+    the same BN-on-mean (gamma) and variance-dropout (DP) tricks apply.
+    """
+    def __init__(self, args, ngpu=1):
+        super(FlowResNetEncoderV2, self).__init__()
+        self.ngpu = ngpu
+        self.nz = args.nz
+        self.nc = 1
+        hidden_units = 512
+        self.main = nn.Sequential(
+            ResNet(self.nc, [64, 64, 64], [2, 2, 2]),
+            nn.Conv2d(64, hidden_units, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(hidden_units),
+            nn.ELU(),
+        )
+        # 4*nz outputs: (loc, scale_arg) plus a 2*nz flow context vector.
+        self.linear = nn.Linear(hidden_units, 4 * self.nz)
+        self.reset_parameters()
+        self.delta_rate = args.delta_rate
+
+        self.args = args
+
+        flow_depth = args.flow_depth
+        flow_width = args.flow_width
+
+        modules = []
+        for _ in range(flow_depth):
+            modules.append(InverseAutoregressiveFlow(num_input=args.nz,
+                                                     num_hidden=flow_width * args.nz,  # hidden dim in MADE
+                                                     num_context=2 * args.nz))
+            modules.append(Reverse(args.nz))
+
+        self.q_z_flow = FlowSequential(*modules)
+        self.log_q_z_0 = NormalLogProb()
+        self.softplus = nn.Softplus()
+
+        self.BN = False
+        if self.args.gamma > 0:
+            self.BN = True
+            self.mu_bn = nn.BatchNorm1d(args.nz, eps=1e-8)
+            self.gamma = args.gamma
+            nn.init.constant_(self.mu_bn.weight, self.args.gamma)
+            nn.init.constant_(self.mu_bn.bias, 0.0)
+
+        self.DP = False
+        if self.args.p_drop > 0 and self.args.delta_rate > 0:
+            self.DP = True
+            self.p_drop = self.args.p_drop
+            self.delta_rate = self.args.delta_rate
+
+    def reset_parameters(self):
+        # He init for convs, identity init for BatchNorm, Xavier for the head.
+        for m in self.main.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        nn.init.xavier_uniform_(self.linear.weight)
+        nn.init.constant_(self.linear.bias, 0.0)
+
+    def forward(self, input, n_samples):
+        """Return flowed samples z_T [batch, n_samples, nz] and log q(z_T|x)."""
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        output = self.linear(output.view(output.size()[:2]))
+        loc_scale, h = output.chunk(2, -1)
+        loc, scale_arg = loc_scale.chunk(2, -1)
+        scale = self.softplus(scale_arg)
+
+        if self.BN:
+            # Rescale the BN weight so its RMS equals gamma before normalizing loc.
+            ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5
+            #if ss < self.gamma:
+            self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss
+            loc = self.mu_bn(loc)
+
+        if self.DP and self.args.kl_weight >= self.args.drop_start:
+            # Variance dropout plus an additive floor of delta_rate/(2*e*pi).
+            var = scale ** 2
+            var = torch.dropout(var, p=self.p_drop, train=self.training)
+            var += self.delta_rate * 1.0 / (2 * math.e * math.pi)
+            scale = var ** 0.5
+
+        loc = loc.unsqueeze(1)
+        scale = scale.unsqueeze(1)
+        h = h.unsqueeze(1)
+
+        eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device)
+        z_0 = loc + scale * eps  # reparameterization
+        log_q_z_0 = self.log_q_z_0(loc=loc, scale=scale, z=z_0)
+        z_T, log_q_z_flow = self.q_z_flow(z_0, context=h)
+        log_q_z = (log_q_z_0 + log_q_z_flow)  # [b s nz]
+
+        # NOTE(review): leftover debug trap — drops into ipdb on NaN.
+        if torch.sum(torch.isnan(z_T)):
+            import ipdb
+            ipdb.set_trace()
+
+        # NOTE(review): rarely-taken no-op branch (bare attribute access),
+        # presumably leftover debug logging.
+        if torch.rand(1).sum() <= 0.001:
+            if self.BN:
+                self.mu_bn.weight
+
+        return z_T, log_q_z
+
+
+    def infer_param(self, input):
+        # Return (loc, logvar) of the base Gaussian (before the flow).
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        output = self.linear(output.view(output.size()[:2]))
+        loc_scale, h = output.chunk(2, -1)
+        loc, scale_arg = loc_scale.chunk(2, -1)
+        scale = self.softplus(scale_arg)
+
+        if self.BN:
+            ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5
+            if ss < self.gamma:
+                self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss
+            loc = self.mu_bn(loc)
+        if self.DP and self.args.kl_weight >= self.args.drop_start:
+            var = scale ** 2
+            var = torch.dropout(var, p=self.p_drop, train=self.training)
+            var += self.delta_rate * 1.0 / (2 * math.e * math.pi)
+            scale = var ** 0.5
+
+        return loc, torch.log(scale ** 2)
+
+    def learn_feature(self, input):
+        # Push the (deterministic) posterior mean through the flow as a feature.
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        output = self.linear(output.view(output.size()[:2]))
+        loc_scale, h = output.chunk(2, -1)
+        loc, _ = loc_scale.chunk(2, -1)
+        if self.BN:
+            ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5
+            if ss < self.gamma:
+                self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss
+            loc = self.mu_bn(loc)
+        loc = loc.unsqueeze(1)
+        h = h.unsqueeze(1)
+        z_T, log_q_z_flow = self.q_z_flow(loc, context=h)
+        return loc, z_T
+
+
+class NormalLogProb(nn.Module):
+    """Elementwise log-density of a diagonal Gaussian; defaults to N(0, I)."""
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, z, loc=None, scale=None):
+        # Missing loc/scale default to the standard normal prior.
+        if loc is None:
+            loc = torch.zeros_like(z, device=z.device)
+        if scale is None:
+            scale = torch.ones_like(z, device=z.device)
+
+        var = torch.pow(scale, 2)
+        # log N(z; loc, var) per dimension (no summation over nz here).
+        return -0.5 * torch.log(2 * np.pi * var) - torch.pow(z - loc, 2) / (2 * var)
diff --git a/modules/encoders/enc_lstm.py b/modules/encoders/enc_lstm.py
new file mode 100644
index 0000000..70dba5c
--- /dev/null
+++ b/modules/encoders/enc_lstm.py
@@ -0,0 +1,120 @@
+import math
+import torch
+import torch.nn as nn
+from .encoder import GaussianEncoderBase
+
+
+class LSTMEncoder(GaussianEncoderBase):
+    """Gaussian LSTM Encoder with constant-length batching"""
+
+    def __init__(self, args, vocab_size, model_init, emb_init):
+        super(LSTMEncoder, self).__init__()
+        self.ni = args.ni
+        self.nh = args.enc_nh
+        self.nz = args.nz
+        self.args = args
+        self.embed = nn.Embedding(vocab_size, args.ni)
+
+        self.lstm = nn.LSTM(input_size=args.ni,
+                            hidden_size=args.enc_nh,
+                            num_layers=1,
+                            batch_first=True,
+                            dropout=0)
+
+        # dimension transformation to z (mean and logvar)
+        self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False)
+
+        self.reset_parameters(model_init, emb_init)
+        # Additive variance floor coefficient (see forward()).
+        self.delta_rate = args.delta_rate
+
+    def reset_parameters(self, model_init, emb_init):
+        for param in self.parameters():
+            model_init(param)
+        emb_init(self.embed.weight)
+
+    def forward(self, input):
+        """
+        Args:
+            x: (batch_size, seq_len)
+
+        Returns: Tensor1, Tensor2
+            Tensor1: the mean tensor, shape (batch, nz)
+            Tensor2: the logvar tensor, shape (batch, nz)
+        """
+        word_embed = self.embed(input)
+        _, (last_state, last_cell) = self.lstm(word_embed)
+        mean, logvar = self.linear(last_state).chunk(2, -1)
+        # Floor the variance at delta_rate/(2*e*pi) to avoid degenerate posteriors.
+        logvar = torch.log(torch.exp(logvar) + self.delta_rate * 1.0 / (2 * math.e * math.pi))
+
+        return mean.squeeze(0), logvar.squeeze(0)
+
+
+class GaussianLSTMEncoder(GaussianEncoderBase):
+    """Gaussian LSTM encoder with batch-normalized posterior mean (gamma > 0)."""
+
+    def __init__(self, args, vocab_size, model_init, emb_init):
+        super(GaussianLSTMEncoder, self).__init__()
+        self.ni = args.ni
+        self.nh = args.enc_nh
+        self.nz = args.nz
+        self.args = args
+
+        self.embed = nn.Embedding(vocab_size, args.ni)
+
+        self.lstm = nn.LSTM(input_size=args.ni,
+                            hidden_size=args.enc_nh,
+                            num_layers=1,
+                            batch_first=True,
+                            dropout=0)
+
+        self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False)
+        self.mu_bn = nn.BatchNorm1d(args.nz)
+        self.gamma = args.gamma
+
+        self.reset_parameters(model_init, emb_init)
+        # Additive variance floor coefficient (see forward()).
+        self.delta_rate = args.delta_rate
+
+    def reset_parameters(self, model_init, emb_init, reset=False):
+        # NOTE(review): unlike LSTMEncoder, model_init/emb_init are accepted
+        # but never applied here — only the BN parameters are initialized.
+        # Confirm whether that is intentional.
+        if not reset:
+            nn.init.constant_(self.mu_bn.weight, self.args.gamma)
+        else:
+            print('reset bn!')
+            if self.args.gamma_train:
+                nn.init.constant_(self.mu_bn.weight, self.args.gamma)
+            else:
+                self.mu_bn.weight.fill_(self.args.gamma)
+        nn.init.constant_(self.mu_bn.bias, 0.0)
+
+    def forward(self, input):
+        """
+        Args:
+            x: (batch_size, seq_len)
+
+        Returns: Tensor1, Tensor2
+            Tensor1: the mean tensor, shape (batch, nz)
+            Tensor2: the logvar tensor, shape (batch, nz)
+        """
+
+        # (batch_size, seq_len-1, args.ni)
+        word_embed = self.embed(input)
+
+        _, (last_state, last_cell) = self.lstm(word_embed)
+        mean, logvar = self.linear(last_state).chunk(2, -1)
+        # BN is skipped when disabled (gamma <= 0) or for a size-1 training
+        # batch, where BatchNorm1d statistics are undefined.
+        if self.args.gamma <= 0 or (mean.squeeze(0).size(0) == 1 and self.training == True):
+            mean = mean.squeeze(0)
+        else:
+            self.mu_bn.weight.requires_grad = True
+            # Rescale the BN weight so its RMS equals gamma before normalizing.
+            ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5
+            #if ss < self.gamma:
+            self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss
+            mean = self.mu_bn(mean.squeeze(0))
+
+        # NOTE(review): leftover debug trap — drops into ipdb on NaN.
+        if torch.sum(torch.isnan(mean)) or torch.sum(torch.isnan(logvar)):
+            import ipdb
+            ipdb.set_trace()
+
+        # Variance floor is only applied once KL annealing has finished.
+        if self.args.kl_weight == 1:
+            logvar = torch.log(torch.exp(logvar) + self.delta_rate * 1.0 / (2 * math.e * math.pi))
+
+        return mean, logvar.squeeze(0)
+
diff --git a/modules/encoders/enc_resnet_v2.py b/modules/encoders/enc_resnet_v2.py
new file mode 100755
index 0000000..d4a1bab
--- /dev/null
+++ b/modules/encoders/enc_resnet_v2.py
@@ -0,0 +1,203 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from sympy import *
+from .encoder import GaussianEncoderBase
+
+"""
+A better ResNet baseline
+"""
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    "3x3 convolution with padding (bias-free; BatchNorm supplies the bias)"
+    return nn.Conv2d(in_planes, out_planes,
+                     kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+def deconv3x3(in_planes, out_planes, stride=1, output_padding=0):
+    "3x3 deconvolution with padding (bias-free; BatchNorm supplies the bias)"
+    return nn.ConvTranspose2d(in_planes, out_planes,
+                              kernel_size=3, stride=stride, padding=1,
+                              output_padding=output_padding, bias=False)
+
+
+class ResNetBlock(nn.Module):
+    """Basic two-conv residual block with ELU and optional 1x1 downsample shortcut."""
+    def __init__(self, inplanes, planes, stride=1):
+        super(ResNetBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.activation = nn.ELU()
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        downsample = None
+        # Projection shortcut when the spatial size or channel count changes.
+        if stride != 1 or inplanes != planes:
+            downsample = nn.Sequential(
+                nn.Conv2d(inplanes, planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes),
+            )
+        self.downsample = downsample
+        self.stride = stride
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # He init for convs; BatchNorm starts as identity (weight=1, bias=0).
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        # [batch, planes, ceil(h/stride), ceil(w/stride)]
+        residual = x if self.downsample is None else self.downsample(x)
+
+        # [batch, planes, ceil(h/stride), ceil(w/stride)]
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.activation(out)
+
+        # [batch, planes, ceil(h/stride), ceil(w/stride)]
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        out = self.activation(out + residual)
+
+        # [batch, planes, ceil(h/stride), ceil(w/stride)]
+        return out
+
+
+class ResNet(nn.Module):
+    """Sequential stack of ResNetBlocks; planes[i]/strides[i] define each stage."""
+    def __init__(self, inplanes, planes, strides):
+        super(ResNet, self).__init__()
+        assert len(planes) == len(strides)
+
+        blocks = []
+        for i in range(len(planes)):
+            plane = planes[i]
+            stride = strides[i]
+            block = ResNetBlock(inplanes, plane, stride=stride)
+            blocks.append(block)
+            # Output channels of this block feed the next one.
+            inplanes = plane
+
+        self.main = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        return self.main(x)
+
+
+class ResNetEncoderV2(GaussianEncoderBase):
+    """ResNet image encoder producing a diagonal-Gaussian posterior (mu, logvar)."""
+    def __init__(self, args, ngpu=1):
+        super(ResNetEncoderV2, self).__init__()
+        self.ngpu = ngpu
+        self.nz = args.nz
+        self.nc = 1
+        hidden_units = 512
+        self.main = nn.Sequential(
+            ResNet(self.nc, [64, 64, 64], [2, 2, 2]),
+            nn.Conv2d(64, hidden_units, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(hidden_units),
+            nn.ELU(),
+        )
+        # Head emits both halves of the posterior: mean and log-variance.
+        self.linear = nn.Linear(hidden_units, 2 * self.nz)
+        self.reset_parameters()
+        self.delta_rate = args.delta_rate
+
+        self.args = args
+
+    def reset_parameters(self):
+        # He init for convs, identity init for BatchNorm, Xavier for the head.
+        for m in self.main.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        nn.init.xavier_uniform_(self.linear.weight)
+        nn.init.constant_(self.linear.bias, 0.0)
+
+    def forward(self, input):
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        # Trunk output is [batch, hidden, 1, 1]; flatten before the linear head.
+        output = self.linear(output.view(output.size()[:2]))
+        mu, logvar = output.chunk(2, 1)
+
+        return mu, logvar
+
+
+class BNResNetEncoderV2(GaussianEncoderBase):
+    """ResNet encoder variant with batch-normalized posterior mean (gamma > 0)."""
+    def __init__(self, args, ngpu=1):
+        super(BNResNetEncoderV2, self).__init__()
+        self.ngpu = ngpu
+        self.nz = args.nz
+        self.nc = 1
+        hidden_units = 512
+        self.main = nn.Sequential(
+            ResNet(self.nc, [64, 64, 64], [2, 2, 2]),
+            nn.Conv2d(64, hidden_units, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(hidden_units),
+            nn.ELU(),
+        )
+        self.linear = nn.Linear(hidden_units, 2 * self.nz)
+        self.mu_bn = nn.BatchNorm1d(args.nz)
+        self.gamma = args.gamma
+        self.args = args
+
+        self.reset_parameters()
+        # Additive variance floor coefficient (see forward()).
+        self.delta_rate = args.delta_rate
+
+    def reset_parameters(self, reset=False):
+        # He init for convs, identity init for BatchNorm, Xavier for the head.
+        for m in self.main.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        nn.init.xavier_uniform_(self.linear.weight)
+        nn.init.constant_(self.linear.bias, 0.0)
+
+        # NOTE(review): both branches initialize mu_bn.weight identically;
+        # the reset flag only adds a print.
+        if not reset:
+            nn.init.constant_(self.mu_bn.weight, self.gamma)
+        else:
+            print('reset bn!')
+
+            nn.init.constant_(self.mu_bn.weight, self.gamma)
+        nn.init.constant_(self.mu_bn.bias, 0.0)
+
+    def forward(self, input):
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        # Trunk output is [batch, hidden, 1, 1]; flatten before the linear head.
+        output = self.linear(output.view(output.size()[:2]))
+        mu, logvar = output.chunk(2, 1)
+        if self.args.gamma > 0:
+            self.mu_bn.weight.requires_grad = True
+            # Rescale the BN weight so its RMS equals gamma before normalizing.
+            ss = torch.mean(self.mu_bn.weight.data ** 2) ** 0.5
+            #if ss < self.gamma:
+            self.mu_bn.weight.data = self.mu_bn.weight.data * self.gamma / ss
+
+            mu = self.mu_bn(mu.squeeze(0))
+        else:
+            mu = mu.squeeze(0)
+
+        # Variance floor is only applied once KL annealing has finished.
+        if self.args.kl_weight == 1:
+            logvar = torch.log(torch.exp(logvar) + self.delta_rate * 1.0 / (2 * math.e * math.pi))
+
+        # NOTE(review): leftover debug branch — `scale` is computed but unused
+        # (the prints below are commented out).
+        if torch.rand(1).sum() <= 0.001:
+            scale = torch.exp(logvar / 2)
+            # print('gamma', self.mu_bn.weight)
+            # print('train loc mean', torch.mean(mu, dim=0))
+            # print('train scale std', torch.std(scale, dim=0))
+            # print('train scale mean', torch.mean(scale, dim=0))
+
+        return mu, logvar
+
diff --git a/modules/encoders/encoder.py b/modules/encoders/encoder.py
new file mode 100644
index 0000000..d13cbc2
--- /dev/null
+++ b/modules/encoders/encoder.py
@@ -0,0 +1,189 @@
+import math
+import torch
+import torch.nn as nn
+import random
+from ..utils import log_sum_exp
+
+
class GaussianEncoderBase(nn.Module):
    """Base class for encoders that produce a diagonal-Gaussian q(z|x).

    Subclasses implement forward(); sampling, KL computation, and the
    mutual-information estimate are shared here.
    """

    def __init__(self):
        super(GaussianEncoderBase, self).__init__()

    def forward(self, x):
        """
        Args:
            x: (batch_size, *)

        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """

        raise NotImplementedError

    def sample(self, input, nsamples):
        """sampling from the encoder
        Returns: Tensor1, Tuple
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tuple: contains the tensor mu [batch, nz] and
                logvar [batch, nz]
        """

        # (batch_size, nz)
        mu, logvar = self.forward(input)
        # if self.args.gamma<0:
        #     mu, logvar = self.trans_param(mu, logvar)

        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)
        # if self.args.gamma <0:
        #     z=self.z_bn(z.squeeze(1)).unsqueeze(1)

        return z, (mu, logvar)

    def encode(self, input, args, training=True):
        """perform the encoding and compute the KL term

        Returns: Tensor1, Tensor2
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]

        """
        nsamples = args.nsamples
        # (batch_size, nz)
        mu, logvar = self.forward(input)

        if args.p_drop > 0 and training and args.kl_weight == 1:
            # Dropout on the posterior variance: subtract the constant floor
            # delta_rate/(2*pi*e) (assumes forward() added it — matches the
            # concrete encoders in this repo), randomly zero entries (note
            # torch.dropout rescales survivors by 1/(1-p)), then restore the
            # floor so logvar stays finite.
            var = logvar.exp() - args.delta_rate * 1.0 / (2 * math.e * math.pi)
            var = torch.dropout(var, p=args.p_drop, train=training)
            logvar = torch.log(var + args.delta_rate * 1.0 / (2 * math.e * math.pi))

        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)

        # analytic KL(q(z|x) || N(0, I)), per dimension, then summed over nz
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
        KL = KL.sum(dim=1)  # (batch,)

        # if torch.sum(torch.isnan(HKL)):
        #     import ipdb
        #     ipdb.set_trace()

        return z, KL

    def MPD(self,mu,logvar):
        """Batch-aggregated pairwise divergence, per latent dimension.

        Returns 0.5 * (B + A*C - 1) of shape [nz], where per dimension:
        A = mean variance, C = mean precision, and B averages the
        variance-scaled squared differences between all pairs of means.
        """
        eps = 1e-9
        z_shape = [mu.size(1)]
        batch_size = mu.size(0)
        # [batch, z_shape]
        var = logvar.exp()
        # B [batch, batch, z_shape]
        B = (mu.unsqueeze(1) - mu.unsqueeze(0)).pow(2).div(var.unsqueeze(0) + eps)
        B = B.mean(0).mean(0)  # z_shape
        A = var.mean(0)
        C = (1/(var+eps)).mean(0)
        # if torch.max(var)>10:
        #     import ipdb
        #     ipdb.set_trace()
        return 0.5*(B+A*C-1)

    def CE(self, logvar):
        """Batch-averaged differential entropy of q(z|x) per dimension, shape [nz]."""
        return 0.5*(logvar.mean(dim=0) + math.log(2*math.pi*math.e))

    def reparameterize(self, mu, logvar, nsamples=1):
        """sample from posterior Gaussian family
        Args:
            mu: Tensor
                Mean of gaussian distribution with shape (batch, nz)

            logvar: Tensor
                logvar of gaussian distibution with shape (batch, nz)

        Returns: Tensor
            Sampled z with shape (batch, nsamples, nz)
        """

        batch_size, nz = mu.size()
        std = logvar.mul(0.5).exp()

        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)

        # z = mu + std * eps, eps ~ N(0, I)
        eps = torch.zeros_like(std_expd).normal_()

        return mu_expd + torch.mul(eps, std_expd)

    def eval_inference_dist(self, x, z, param=None):
        """this function computes log q(z | x)
        Args:
            z: tensor
                different z points that will be evaluated, with
                shape [batch, nsamples, nz]
            param: optional (mu, logvar) to reuse; otherwise forward(x) is called
        Returns: Tensor1
            Tensor1: log q(z|x) with shape [batch, nsamples]
        """

        nz = z.size(2)

        if not param:
            mu, logvar = self.forward(x)
        else:
            mu, logvar = param

        # (batch_size, 1, nz)
        mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
        var = logvar.exp()

        # (batch_size, nsamples, nz)
        dev = z - mu

        # diagonal-Gaussian log density, summed over nz
        # (batch_size, nsamples)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
                      0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        return log_density

    def calc_mi(self, x):
        """Approximate the mutual information between x and z
        I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))

        Returns: Float

        """

        # [x_batch, nz]
        mu, logvar = self.forward(x)

        # if self.args.gamma<0:
        #     mu, logvar = self.trans_param( mu, logvar)

        x_batch, nz = mu.size()

        # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
        neg_entropy = (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).mean()

        # [z_batch, 1, nz] (one sample per x)
        z_samples = self.reparameterize(mu, logvar, 1)

        # [1, x_batch, nz]
        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (z_batch, x_batch, nz)
        dev = z_samples - mu

        # log q(z_i | x_j) for every pair of sample and posterior
        # (z_batch, x_batch)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
                      0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): aggregate posterior, Monte-Carlo averaged over the batch
        # [z_batch]
        log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)

        return (neg_entropy - log_qz.mean(-1)).item()
+
+
+
+
+
diff --git a/modules/encoders/flow.py b/modules/encoders/flow.py
new file mode 100755
index 0000000..1b7bb94
--- /dev/null
+++ b/modules/encoders/flow.py
@@ -0,0 +1,152 @@
+"""Credit: mostly based on Ilya's excellent implementation here: https://github.com/ikostrikov/pytorch-flows"""
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
class InverseAutoregressiveFlow(nn.Module):
    """Inverse Autoregressive Flows with LSTM-type update. One block.

    Eq 11-14 of https://arxiv.org/abs/1606.04934
    """

    def __init__(self, num_input, num_hidden, num_context):
        super().__init__()
        # MADE emits both m and s, hence 2 * num_input output channels
        self.made = MADE(num_input=num_input, num_output=num_input * 2,
                         num_hidden=num_hidden, num_context=num_context)
        # init such that sigmoid(s) is close to 1 for stability
        self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2)
        self.sigmoid = nn.Sigmoid()
        self.log_sigmoid = nn.LogSigmoid()

    def forward(self, input, context=None):
        """One IAF step: z = sigma(s) * x + (1 - sigma(s)) * m.

        Returns (z, -log sigma(s)); the second output is accumulated by
        FlowSequential as the per-dimension log-probability correction.
        """
        m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1)
        s = s + self.sigmoid_arg_bias
        sigmoid = self.sigmoid(s)
        z = sigmoid * input + (1 - sigmoid) * m
        return z, -self.log_sigmoid(s)
+
+
class FlowSequential(nn.Sequential):
    """Chain of flow blocks; accumulates each block's log-prob correction."""

    def forward(self, input, context=None):
        out = input
        log_det = torch.zeros_like(input, device=input.device)
        for layer in self._modules.values():
            out, step = layer(out, context)
            log_det += step
        return out, log_det
+
+
class MaskedLinear(nn.Module):
    """Linear layer whose weight matrix is elementwise-masked.

    Optionally adds an unmasked, bias-free linear projection of a
    context vector to the output.
    """

    def __init__(self, in_features, out_features, mask, context_features=None, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        # mask is fixed (not trained) but moves with the module's device/dtype
        self.register_buffer("mask", mask)
        if context_features is not None:
            self.cond_linear = nn.Linear(context_features, out_features, bias=False)

    def forward(self, input, context=None):
        out = F.linear(input, self.mask * self.linear.weight, self.linear.bias)
        if context is not None:
            out = out + self.cond_linear(context)
        return out
+
+
class MADE(nn.Module):
    """Implements MADE: Masked Autoencoder for Distribution Estimation.

    Follows https://arxiv.org/abs/1502.03509

    This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057).
    """

    def __init__(self, num_input, num_output, num_hidden, num_context):
        super().__init__()
        # m corresponds to m(k), the maximum degree of a node in the MADE paper
        self._m = []
        self._masks = []
        self._build_masks(num_input, num_output, num_hidden, num_layers=3)
        self._check_masks()
        modules = []
        # only the first layer receives the (unmasked) context vector
        self.input_context_net = MaskedLinear(num_input, num_hidden, self._masks[0], num_context)
        modules.append(nn.ReLU())
        modules.append(MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None))
        modules.append(nn.ReLU())
        modules.append(MaskedLinear(num_hidden, num_output, self._masks[2], context_features=None))
        self.net = nn.Sequential(*modules)

    def _build_masks(self, num_input, num_output, num_hidden, num_layers):
        """Build the masks according to Eq 12 and 13 in the MADE paper."""
        # fixed seed so the degree assignment (and hence the masks) is deterministic
        rng = np.random.RandomState(0)
        # assign input units a number between 1 and D
        self._m.append(np.arange(1, num_input + 1))
        for i in range(1, num_layers + 1):
            # randomly assign maximum number of input nodes to connect to
            if i == num_layers:
                # assign output layer units a number between 1 and D
                m = np.arange(1, num_input + 1)
                assert num_output % num_input == 0, "num_output must be multiple of num_input"
                self._m.append(np.hstack([m for _ in range(num_output // num_input)]))
            else:
                # assign hidden layer units a number between 1 and D-1
                self._m.append(rng.randint(1, num_input, size=num_hidden))
            # self._m.append(np.arange(1, num_hidden + 1) % (num_input - 1) + 1)
            if i == num_layers:
                # output mask uses strict inequality (Eq 13) so output k never
                # sees input k — preserves the autoregressive property
                mask = self._m[i][None, :] > self._m[i - 1][:, None]
            else:
                # input to hidden & hidden to hidden
                mask = self._m[i][None, :] >= self._m[i - 1][:, None]
            # need to transpose for torch linear layer, shape (num_output, num_input)
            self._masks.append(torch.from_numpy(mask.astype(np.float32).T))

    def _check_masks(self):
        """Check that the connectivity matrix between layers is lower triangular."""
        # (num_input, num_hidden)
        prev = self._masks[0].t()
        for i in range(1, len(self._masks)):
            # num_hidden is second axis
            prev = prev @ self._masks[i].t()
        final = prev.numpy()
        num_input = self._masks[0].shape[1]
        num_output = self._masks[-1].shape[0]
        assert final.shape == (num_input, num_output)
        if num_output == num_input:
            assert np.triu(final).all() == 0
        else:
            # num_output is a multiple of num_input; check each stacked copy
            for submat in np.split(final,
                                   indices_or_sections=num_output // num_input,
                                   axis=1):
                assert np.triu(submat).all() == 0

    def forward(self, input, context=None):
        # first hidden layer receives input and context
        hidden = self.input_context_net(input, context)
        # rest of the network is conditioned on both input and context
        return self.net(hidden)
+
+
class Reverse(nn.Module):
    """Feature-order reversal layer (Real NVP style permutation).

    Reversal is its own inverse, but forward/inverse modes are kept
    separate to mirror the flow-layer API. A permutation has unit
    Jacobian, so the log-det contribution is zero.

    From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py
    """

    def __init__(self, num_input):
        super(Reverse, self).__init__()
        self.perm = np.array(np.arange(0, num_input)[::-1])
        self.inv_perm = np.argsort(self.perm)

    def forward(self, inputs, context=None, mode='forward'):
        if mode not in ("forward", "inverse"):
            raise ValueError("Mode must be one of {forward, inverse}.")
        order = self.perm if mode == "forward" else self.inv_perm
        return inputs[:, :, order], torch.zeros_like(inputs, device=inputs.device)
diff --git a/modules/utils.py b/modules/utils.py
new file mode 100755
index 0000000..9e2d020
--- /dev/null
+++ b/modules/utils.py
@@ -0,0 +1,39 @@
+import torch
+
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable log(sum(exp(value))) over `dim` (or all elements).

    Subtracts the maximum before exponentiating so large values do not
    overflow, then adds it back after the log.
    """
    if dim is None:
        # reduce over every element
        max_val = torch.max(value)
        return max_val + torch.log(torch.sum(torch.exp(value - max_val)))
    max_val, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - max_val
    if not keepdim:
        max_val = max_val.squeeze(dim)
    return max_val + torch.log(torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim))
+
+
def generate_grid(zmin, zmax, dz, device, ndim=2):
    """Generate a 1- or 2-dimensional grid of evaluation points.

    Args:
        zmin, zmax, dz: grid range [zmin, zmax) and step, as in torch.arange.
        device: device the returned tensor(s) are moved to.
        ndim: 1 or 2.

    Returns:
        ndim == 2: (Tensor of shape (k^2, 2), k) where k = len(arange(zmin, zmax, dz))
        ndim == 1: Tensor of shape (k, 1)

    Raises:
        ValueError: if ndim is not 1 or 2 (the original silently returned None).
    """
    axis = torch.arange(zmin, zmax, dz)

    if ndim == 1:
        return axis.unsqueeze(1).to(device)

    if ndim == 2:
        k = axis.size(0)
        # cartesian product of axis with itself: rows vary slowly, cols fast
        rows = axis.unsqueeze(1).repeat(1, k).view(-1)
        cols = axis.repeat(k)
        return torch.stack((rows, cols), dim=-1).to(device), k

    raise ValueError("ndim must be 1 or 2, got %r" % (ndim,))
+
+
diff --git a/modules/vae.py b/modules/vae.py
new file mode 100755
index 0000000..d5687bb
--- /dev/null
+++ b/modules/vae.py
@@ -0,0 +1,303 @@
+import math
+import torch
+import torch.nn as nn
+
+from .utils import log_sum_exp
+
class VAE(nn.Module):
    """VAE with a standard-normal prior p(z) = N(0, I)."""
    def __init__(self, encoder, decoder, args):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

        self.args = args

        self.nz = args.nz

        # fixed N(0, I) prior on the configured device
        loc = torch.zeros(self.nz, device=args.device)
        scale = torch.ones(self.nz, device=args.device)

        self.prior = torch.distributions.normal.Normal(loc, scale)

    def encode(self, x, args,training=True):
        """
        Returns: Tensor1, Tensor2
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """
        return self.encoder.encode(x,args,training)

    def encode_stats(self, x):
        """
        Returns: Tensor1, Tensor2
            Tensor1: the mean of latent z with shape [batch, nz]
            Tensor2: the logvar of latent z with shape [batch, nz]
        """

        return self.encoder(x)

    def decode(self, z, strategy, K=5):
        """generate samples from z given strategy

        Args:
            z: [batch, nsamples, nz]
            strategy: "beam" or "greedy" or "sample"
            K: the beam width parameter

        Returns: List1
            List1: a list of decoded word sequence
        """

        if strategy == "beam":
            return self.decoder.beam_search_decode(z, K)
        elif strategy == "greedy":
            return self.decoder.greedy_decode(z)
        elif strategy == "sample":
            return self.decoder.sample_decode(z)
        else:
            raise ValueError("the decoding strategy is not supported")

    def reconstruct(self, x, decoding_strategy="greedy", K=5,beta=1):
        """reconstruct from input x

        Args:
            x: (batch, *)
            decoding_strategy: "beam" or "greedy" or "sample"
            K: the beam width parameter (if applicable)
            beta: forwarded to sample_from_inference, which currently
                ignores it — see NOTE on that method

        Returns: List1
            List1: a list of decoded word sequence
        """
        z = self.sample_from_inference(x,beta=beta).squeeze(1)

        return self.decode(z, decoding_strategy, K)

    def loss(self, x, kl_weight, args,training=True):
        """
        Args:
            x: if the data is constant-length, x is the data tensor with
                shape (batch, *). Otherwise x is a tuple that contains
                the data tensor and length list

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: total loss [batch]
            Tensor2: reconstruction loss shape [batch]
            Tensor3: KL loss shape [batch]
        """

        z, KL= self.encode(x, args,training)

        # (batch)
        reconstruct_err = self.decoder.reconstruct_error(x, z).mean(dim=1)
        # NOTE(review): interactive debugger breakpoint on NaN losses —
        # development aid, should not fire in production runs
        if torch.sum(torch.isnan(reconstruct_err)):
            import ipdb
            ipdb.set_trace()
        return reconstruct_err + kl_weight * KL, reconstruct_err, KL

    def rc_loss(self, x, y, kl_weight, args):
        """Same as loss(), but decodes toward target y instead of the input x."""
        z, KL = self.encode(x, args)
        reconstruct_err = self.decoder.reconstruct_error(y, z).mean(dim=1)
        return reconstruct_err + kl_weight * KL, reconstruct_err, KL

    def nll_iw(self, x, nsamples, ns=100):
        """compute the importance weighting estimate of the log-likelihood
        Args:
            x: if the data is constant-length, x is the data tensor with
                shape (batch, *). Otherwise x is a tuple that contains
                the data tensor and length list
            nsamples: Int
                the number of samples required to estimate marginal data likelihood
        Returns: Tensor1
            Tensor1: the estimate of log p(x), shape [batch]
        """
        # compute iw every ns samples to address the memory issue
        # nsamples = 500, ns = 100
        # nsamples = 500, ns = 10
        # NOTE(review): assumes nsamples >= ns and nsamples % ns == 0;
        # nsamples < ns makes the loop body never run and torch.cat fail
        tmp = []
        for _ in range(int(nsamples / ns)):
            # [batch, ns, nz]
            # param is the parameters required to evaluate q(z|x)
            z, param = self.encoder.sample(x, ns)

            # [batch, ns]
            log_comp_ll = self.eval_complete_ll(x, z)
            log_infer_ll = self.eval_inference_dist(x, z, param)
            tmp.append(log_comp_ll - log_infer_ll)

        ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)

        return -ll_iw

    def KL(self, x, args):
        """Analytic KL(q(z|x) || N(0, I)) per example, shape [batch]."""
        # _, KL = self.encode(x, 1)
        mu, logvar = self.encoder.forward(x)
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        return KL

    def eval_prior_dist(self, zrange):
        """perform grid search to calculate the true posterior
        Args:
            zrange: tensor
                different z points that will be evaluated, with
                shape (k^2, nz), where k=(zmax - zmin)/space
        """

        # (k^2)
        return self.prior.log_prob(zrange).sum(dim=-1)

    def eval_complete_ll(self, x, z):
        """compute log p(z,x)
        Args:
            x: Tensor
                input with shape [batch, seq_len]
            z: Tensor
                evaluation points with shape [batch, nsamples, nz]
        Returns: Tensor1
            Tensor1: log p(z,x) Tensor with shape [batch, nsamples]
        """

        # log p(z, x) = log p(z) + log p(x|z)
        # [batch, nsamples]
        log_prior = self.eval_prior_dist(z)
        log_gen = self.eval_cond_ll(x, z)

        return log_prior + log_gen

    def eval_cond_ll(self, x, z):
        """compute log p(x|z)
        """

        return self.decoder.log_probability(x, z)

    def eval_log_model_posterior(self, x, grid_z):
        """perform grid search to calculate the true posterior
        this function computes p(z|x)
        Args:
            grid_z: tensor
                different z points that will be evaluated, with
                shape (k^2, nz), where k=(zmax - zmin)/pace

        Returns: Tensor
            Tensor: the log posterior distribution log p(z|x) with
                shape [batch_size, K^2]
        """
        # x may be a plain tensor or a (tensor, lengths) tuple
        try:
            batch_size = x.size(0)
        except:
            batch_size = x[0].size(0)

        # (batch_size, k^2, nz)
        grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()

        # (batch_size, k^2)
        log_comp = self.eval_complete_ll(x, grid_z)

        # normalize to posterior
        log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True)

        return log_posterior

    def sample_from_prior(self, nsamples):
        """sampling from prior distribution

        Returns: Tensor
            Tensor: samples from prior with shape (nsamples, nz)
        """
        return self.prior.sample((nsamples,))

    def sample_from_inference(self, x, nsamples=1,beta=1):
        """perform sampling from inference net
        Returns: Tensor
            Tensor: samples from inference nets with
                shape (batch_size, nsamples, nz)
        """
        z, (mu,logvar) = self.encoder.sample(x, nsamples)

        # NOTE(review): returns the posterior MEAN, not the samples z, and
        # ignores both nsamples and beta, contradicting the docstring above.
        # The original author flagged this line with '???????' — confirm
        # whether `return z` was intended before changing.
        return mu

    def sample_from_posterior(self, x, nsamples):
        """perform MH sampling from model posterior
        Returns: Tensor
            Tensor: samples from model posterior with
                shape (batch_size, nsamples, nz)
        """

        # use the samples from inference net as initial points
        # for MCMC sampling. [batch_size, nsamples, nz]
        cur = self.encoder.sample_from_inference(x, 1)
        cur_ll = self.eval_complete_ll(x, cur)
        total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin
        samples = []
        for iter_ in range(total_iter):
            # Gaussian random-walk proposal (note: `next` shadows the builtin)
            next = torch.normal(mean=cur,
                                std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std))
            # [batch_size, 1]
            next_ll = self.eval_complete_ll(x, next)
            ratio = next_ll - cur_ll

            # Metropolis-Hastings acceptance probability min(1, p(next)/p(cur))
            accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size()))

            uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_()

            # [batch_size, 1]
            mask = (uniform_t < accept_prob).float()

            mask_ = mask.unsqueeze(2)

            cur = mask_ * next + (1 - mask_) * cur
            cur_ll = mask * next_ll + (1 - mask) * cur_ll

            # keep every mh_thin-th draw after the burn-in period
            if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0:
                samples.append(cur.unsqueeze(1))

        return torch.cat(samples, dim=1)

    def calc_model_posterior_mean(self, x, grid_z):
        """compute the mean value of model posterior, i.e. E_{z ~ p(z|x)}[z]
        Args:
            grid_z: different z points that will be evaluated, with
                shape (k^2, nz), where k=(zmax - zmin)/pace
            x: [batch, *]

        Returns: Tensor1
            Tensor1: the mean value tensor with shape [batch, nz]

        """

        # [batch, K^2]
        log_posterior = self.eval_log_model_posterior(x, grid_z)
        posterior = log_posterior.exp()

        # expectation over the grid: sum_z p(z|x) * z
        # [batch, nz]
        return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1)

    def calc_infer_mean(self, x):
        """
        Returns: Tensor1
            Tensor1: the mean of inference distribution, with shape [batch, nz]
        """

        mean, logvar = self.encoder.forward(x)
        # if self.args.gamma<0:
        #     mean,logvar=self.encoder.trans_param(mean,logvar)

        return mean

    def eval_inference_dist(self, x, z, param=None):
        """
        Returns: Tensor
            Tensor: the posterior density tensor with
                shape (batch_size, nsamples)
        """
        return self.encoder.eval_inference_dist(x, z, param)

    def calc_mi_q(self, x):
        """Approximate the mutual information between x and z
        under distribution q(z|x)

        Args:
            x: [batch_size, *]. The sampled data to estimate mutual info
        """

        return self.encoder.calc_mi(x)
diff --git a/modules/vae_IAF.py b/modules/vae_IAF.py
new file mode 100644
index 0000000..4d09ca1
--- /dev/null
+++ b/modules/vae_IAF.py
@@ -0,0 +1,309 @@
+import math
+import torch
+import torch.nn as nn
+
+from .utils import log_sum_exp
+
class VAEIAF(nn.Module):
    """VAE with a standard-normal prior whose encoder refines the posterior
    with an inverse autoregressive flow (the encoder's sample() returns the
    log density of the drawn samples rather than (mu, logvar))."""
    def __init__(self, encoder, decoder, args):
        super(VAEIAF, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

        self.args = args

        self.nz = args.nz

        # fixed N(0, I) prior on the configured device
        loc = torch.zeros(self.nz, device=args.device)
        scale = torch.ones(self.nz, device=args.device)

        self.prior = torch.distributions.normal.Normal(loc, scale)

    def encode(self, x, args,training=True):
        """
        Returns: Tensor1, Tensor2
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """
        # NOTE(review): `training` is accepted but NOT forwarded to
        # encoder.encode (unlike VAE.encode), so KL(...) calling with
        # training=False has no effect — confirm whether the IAF encoder's
        # encode accepts a training flag before forwarding it.
        return self.encoder.encode(x, args)

    def encode_stats(self, x):
        """
        Returns: Tensor1, Tensor2
            Tensor1: the mean of latent z with shape [batch, nz]
            Tensor2: the logvar of latent z with shape [batch, nz]
        """

        return self.encoder.infer_param(x)

    def decode(self, z, strategy, K=5):
        """generate samples from z given strategy

        Args:
            z: [batch, nsamples, nz]
            strategy: "beam" or "greedy" or "sample"
            K: the beam width parameter

        Returns: List1
            List1: a list of decoded word sequence
        """

        if strategy == "beam":
            return self.decoder.beam_search_decode(z, K)
        elif strategy == "greedy":
            return self.decoder.greedy_decode(z)
        elif strategy == "sample":
            return self.decoder.sample_decode(z)
        else:
            raise ValueError("the decoding strategy is not supported")

    def reconstruct(self, x, decoding_strategy="greedy", K=5,beta=1):
        """reconstruct from input x

        Args:
            x: (batch, *)
            decoding_strategy: "beam" or "greedy" or "sample"
            K: the beam width parameter (if applicable)
            beta: forwarded to sample_from_inference — see NOTE there

        Returns: List1
            List1: a list of decoded word sequence
        """
        z = self.sample_from_inference(x,beta=beta).squeeze(1)

        return self.decode(z, decoding_strategy, K)

    def loss(self, x, kl_weight, args, training=True):
        """
        Args:
            x: if the data is constant-length, x is the data tensor with
                shape (batch, *). Otherwise x is a tuple that contains
                the data tensor and length list

        Returns: Tensor1, Tensor2, Tensor3
            Tensor1: total loss [batch]
            Tensor2: reconstruction loss shape [batch]
            Tensor3: KL loss shape [batch]
        """

        z, KL = self.encode(x, args, training)

        # (batch)
        reconstruct_err = self.decoder.reconstruct_error(x, z).mean(dim=1)
        # NOTE(review): interactive debugger breakpoint on NaN losses —
        # development aid, should not fire in production runs
        if torch.sum(torch.isnan(reconstruct_err)):
            import ipdb
            ipdb.set_trace()

        return reconstruct_err + kl_weight * KL, reconstruct_err, KL

    def rc_loss(self, x, y, kl_weight, args):
        """Same as loss(), but decodes toward target y instead of the input x."""
        z, KL = self.encode(x, args)
        reconstruct_err = self.decoder.reconstruct_error(y, z).mean(dim=1)
        return reconstruct_err + kl_weight * KL, reconstruct_err, KL

    def nll_iw(self, x, nsamples, ns=100):
        """compute the importance weighting estimate of the log-likelihood
        Args:
            x: if the data is constant-length, x is the data tensor with
                shape (batch, *). Otherwise x is a tuple that contains
                the data tensor and length list
            nsamples: Int
                the number of samples required to estimate marginal data likelihood
        Returns: Tensor1
            Tensor1: the estimate of log p(x), shape [batch]
        """
        # compute iw every ns samples to address the memory issue
        # nsamples = 500, ns = 100
        # nsamples = 500, ns = 10
        # NOTE(review): assumes nsamples >= ns and nsamples % ns == 0;
        # nsamples < ns makes the loop body never run and torch.cat fail
        tmp = []
        for _ in range(int(nsamples / ns)):
            # [batch, ns, nz]
            # the IAF encoder returns the log density of the drawn samples
            # directly (per dimension), so sum over the last axis
            z, log_infer_ll = self.encoder.sample(x, ns)
            log_infer_ll = log_infer_ll.sum(dim=-1)

            # [batch, ns]
            log_comp_ll = self.eval_complete_ll(x, z)

            tmp.append(log_comp_ll - log_infer_ll)

        ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)


        return -ll_iw

    def KL(self, x, args):
        """KL term as computed by the encoder; shape [batch].

        NOTE(review): training=False is passed but dropped by encode()
        above — see the NOTE there.
        """
        # _, KL = self.encode(x, 1)
        z, KL = self.encode(x, args, training=False)

        return KL

    def eval_prior_dist(self, zrange):
        """perform grid search to calculate the true posterior
        Args:
            zrange: tensor
                different z points that will be evaluated, with
                shape (k^2, nz), where k=(zmax - zmin)/space
        """

        # (k^2)
        return self.prior.log_prob(zrange).sum(dim=-1)

    def eval_complete_ll(self, x, z):
        """compute log p(z,x)
        Args:
            x: Tensor
                input with shape [batch, seq_len]
            z: Tensor
                evaluation points with shape [batch, nsamples, nz]
        Returns: Tensor1
            Tensor1: log p(z,x) Tensor with shape [batch, nsamples]
        """

        # log p(z, x) = log p(z) + log p(x|z)
        # [batch, nsamples]
        log_prior = self.eval_prior_dist(z)
        log_gen = self.eval_cond_ll(x, z)

        return log_prior + log_gen

    def eval_cond_ll(self, x, z):
        """compute log p(x|z)
        """

        return self.decoder.log_probability(x, z)

    def eval_log_model_posterior(self, x, grid_z):
        """perform grid search to calculate the true posterior
        this function computes p(z|x)
        Args:
            grid_z: tensor
                different z points that will be evaluated, with
                shape (k^2, nz), where k=(zmax - zmin)/pace

        Returns: Tensor
            Tensor: the log posterior distribution log p(z|x) with
                shape [batch_size, K^2]
        """
        # x may be a plain tensor or a (tensor, lengths) tuple
        try:
            batch_size = x.size(0)
        except:
            batch_size = x[0].size(0)

        # (batch_size, k^2, nz)
        grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()

        # (batch_size, k^2)
        log_comp = self.eval_complete_ll(x, grid_z)

        # normalize to posterior
        log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True)

        return log_posterior

    def sample_from_prior(self, nsamples):
        """sampling from prior distribution

        Returns: Tensor
            Tensor: samples from prior with shape (nsamples, nz)
        """
        return self.prior.sample((nsamples,))

    def sample_from_inference(self, x, nsamples=1,beta=1):
        """perform sampling from inference net
        Returns: Tensor
            Tensor: samples from inference nets with
                shape (batch_size, nsamples, nz)
        """
        z, (mu,logvar) = self.encoder.sample(x, nsamples)
        if beta==0:
            z=mu

        # NOTE(review): mu is returned unconditionally, so the beta==0 branch
        # above has no effect and the docstring (samples) does not match.
        # `return z` looks intended — the original author flagged this line
        # with '???????'; confirm before changing.
        return mu

    def sample_from_posterior(self, x, nsamples):
        """perform MH sampling from model posterior
        Returns: Tensor
            Tensor: samples from model posterior with
                shape (batch_size, nsamples, nz)
        """

        # use the samples from inference net as initial points
        # for MCMC sampling. [batch_size, nsamples, nz]
        cur = self.encoder.sample_from_inference(x, 1)
        cur_ll = self.eval_complete_ll(x, cur)
        total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin
        samples = []
        for iter_ in range(total_iter):
            # Gaussian random-walk proposal (note: `next` shadows the builtin)
            next = torch.normal(mean=cur,
                                std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std))
            # [batch_size, 1]
            next_ll = self.eval_complete_ll(x, next)
            ratio = next_ll - cur_ll

            # Metropolis-Hastings acceptance probability min(1, p(next)/p(cur))
            accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size()))

            uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_()

            # [batch_size, 1]
            mask = (uniform_t < accept_prob).float()

            mask_ = mask.unsqueeze(2)

            cur = mask_ * next + (1 - mask_) * cur
            cur_ll = mask * next_ll + (1 - mask) * cur_ll

            # keep every mh_thin-th draw after the burn-in period
            if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0:
                samples.append(cur.unsqueeze(1))

        return torch.cat(samples, dim=1)

    def calc_model_posterior_mean(self, x, grid_z):
        """compute the mean value of model posterior, i.e. E_{z ~ p(z|x)}[z]
        Args:
            grid_z: different z points that will be evaluated, with
                shape (k^2, nz), where k=(zmax - zmin)/pace
            x: [batch, *]

        Returns: Tensor1
            Tensor1: the mean value tensor with shape [batch, nz]

        """

        # [batch, K^2]
        log_posterior = self.eval_log_model_posterior(x, grid_z)
        posterior = log_posterior.exp()

        # expectation over the grid: sum_z p(z|x) * z
        # [batch, nz]
        return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1)

    def calc_infer_mean(self, x):
        """
        Returns: Tensor1
            Tensor1: the mean of inference distribution, with shape [batch, nz]
        """

        mean, logvar = self.encoder.forward(x)
        # if self.args.gamma<0:
        #     mean,logvar=self.encoder.trans_param(mean,logvar)

        return mean

    def eval_inference_dist(self, x, z, param=None):
        """
        Returns: Tensor
            Tensor: the posterior density tensor with
                shape (batch_size, nsamples)
        """


        return self.encoder.eval_inference_dist(x, z, param)

    def calc_mi_q(self, x):
        """Approximate the mutual information between x and z
        under distribution q(z|x)

        Args:
            x: [batch_size, *]. The sampled data to estimate mutual info
        """

        return self.encoder.calc_mi(x)
diff --git a/omniglotDataset.py b/omniglotDataset.py
new file mode 100644
index 0000000..454fc18
--- /dev/null
+++ b/omniglotDataset.py
@@ -0,0 +1,150 @@
+import os.path
+import numpy as np
+import torch
+import pickle
+
+
+
def process_data(x, pad=15):
    """Group (img, label) pairs by label, pad short classes, regroup by alphabet.

    Each label is a (alphabet, character) tuple. Classes with fewer than
    `pad` images are topped up by resampling their own images (with
    replacement). Returns {alphabet: np.ndarray of shape (num_chars, pad, ...)}.
    """
    # 1) bucket images per (alphabet, character) label
    by_label = {}
    for img, label in x:
        by_label.setdefault(label, []).append(img)

    # 2) pad under-filled classes by duplicating random members
    for label, imgs in by_label.items():
        if len(imgs) < pad:
            print(label, len(imgs))
            picks = np.random.choice(len(imgs), pad - len(imgs))
            for idx in picks:
                imgs.append(imgs[idx])

    # 3) regroup the per-character lists under their alphabet, in sorted order
    grouped = {}
    for key in sorted(by_label.keys()):
        alphabet, _ = key
        grouped.setdefault(alphabet, []).append(by_label[key])
    return {alphabet: np.array(chars) for alphabet, chars in grouped.items()}
+
def feature(x,encoder,device, IAF =False):
    """Encode a stack of 28x28 images into latent features with a trained encoder.

    Args:
        x: image tensor whose first two dims are kept; the rest is reshaped
            to (-1, 1, 28, 28) for the encoder.
        encoder: trained encoder; called as encoder(batch) -> (mu, _), or,
            when IAF is True, encoder.learn_feature(batch) -> (mu, zT).
        device: torch device the batches are moved to.
        IAF: if True, concatenate the flow output zT to mu.

    Returns:
        Detached feature tensor of shape (dim0, dim1, feature_dim).
    """
    dim0 = x.size(0)
    dim1 = x.size(1)
    tmp = x.reshape(-1, 1, 28, 28)
    # dummy labels only so TensorDataset/DataLoader batching can be reused
    label = torch.zeros(tmp.size(0), 1)
    tmp_data = torch.utils.data.TensorDataset(tmp, label)
    loader = torch.utils.data.DataLoader(tmp_data, batch_size=256, shuffle=False)
    feature = []
    for datum in loader:
        batch_data, _ = datum
        batch_data = batch_data.to(device)
        if IAF:
            # use both the Gaussian mean and the flow-transformed sample
            mu, zT = encoder.learn_feature(batch_data)
            mu = torch.cat([mu,zT],dim=-1)
        else:
            mu, _ = encoder(batch_data)
        feature.append(mu.detach())
    x = torch.cat(feature, dim=0).reshape(dim0, dim1, -1)
    return x
+
+
class Omniglot:
    """Omniglot dataset wrapper: groups images per alphabet, caches the split
    as a pickle, and optionally replaces raw images by encoder features."""

    def __init__(self, root, encoder=None, device =None, IAF=False):
        """
        Different from mnistNShot, the
        :param root: directory containing omniglot.npy (raw) or the cached
            omniglot_dataset.pkl
        :param encoder: optional trained encoder; when given, images are
            binarized and replaced by latent features via feature()
        :param device: torch device for encoder forward passes
        :param IAF: forwarded to feature() (append flow features)
        """
        print(root)
        # build the grouped dataset once and cache it as a pickle
        if not os.path.isfile(os.path.join(root, 'omniglot_dataset.pkl')):
            x_train, x_test = np.load(os.path.join(root, 'omniglot.npy'),allow_pickle=True)

            train_dict = process_data(x_train,pad=15)
            test_dict = process_data(x_test,pad=5)
            self.data_dict={}
            for a in train_dict:
                self.data_dict[a]={'train':train_dict[a],'test':test_dict[a]}
            pickle.dump(self.data_dict, open(root + '/omniglot_dataset.pkl','wb'))
        else:
            print('load from meta_learning_dataset.pt.')
            self.data_dict = pickle.load(open(root + '/omniglot_dataset.pkl','rb'))

        if encoder:
            print('begin learning feature')
            # NOTE(review): assumes exactly 50 alphabets keyed 0..49 — confirm
            for a in range(50):
                x_train = self.data_dict[a]['train']
                x_test = self.data_dict[a]['test']
                print(a,x_train.shape[0])
                x_train_s=[]
                x_test_s =[]
                x_train = torch.from_numpy(x_train).to(device)
                x_test = torch.from_numpy(x_test).to(device)
                # for _ in range(5):
                # stochastically binarize pixel intensities before encoding
                x_train = torch.bernoulli(x_train)
                # x_train_s.append(x_train)
                x_test = torch.bernoulli(x_test)
                # x_test_s.append(x_test)
                # x_train = torch.cat(x_train_s,dim =1)
                # x_test = torch.cat(x_test_s, dim =1 )

                x_train = feature(x_train,encoder,device,IAF)
                x_test = feature(x_test,encoder,device,IAF)
                self.data_dict[a] = {'train': x_train, 'test': x_test}
            print('Done!')

    def load_task(self,i, trainnum=10):
        """Return (x_train, l_train, x_test, l_test, num_classes) for task i.

        i < 50 selects one alphabet; i == 50 concatenates all alphabets.
        Labels are the per-character class indices. The try branch handles
        torch tensors (encoder features); the except branch handles raw
        numpy image arrays, which keep their trailing image dims.
        """
        if i < 50:
            x_train = self.data_dict[i]['train']
            x_test = self.data_dict[i]['test']
        elif i==50:
            x_train_s =[]
            x_test_s =[]
            for a in range(50):
                x_train_s.append(self.data_dict[a]['train'])
                x_test_s.append(self.data_dict[a]['test'])
            x_train = torch.cat(x_train_s,dim =0)
            x_test = torch.cat(x_test_s,dim=0)
        try :
            # torch-tensor path: flatten (classes, shots, feat) to rows
            x_train = x_train[:,:trainnum,:]
            NC,N = x_train.size()[:2]
            label = torch.tensor(range(NC)).unsqueeze(1).expand(-1,N)
            x_train = x_train.reshape(NC*N,-1)
            l_train = label.reshape(NC * N, -1)

            NC, N = x_test.size()[:2]
            label = torch.tensor(range(NC)).unsqueeze(1).expand(-1, N)
            x_test = x_test.reshape(NC * N, -1)
            l_test = label.reshape(NC * N, -1)

            return x_train,l_train,x_test,l_test, NC
        except:
            # numpy path: same flattening but image dims are preserved
            NC, N = x_train.shape[:2]
            ds = x_train.shape[2:]
            label = np.array(range(NC))[:,np.newaxis].repeat(N,axis=1)
            x_train = x_train.reshape(NC * N, *ds)
            l_train = label.reshape(NC * N, -1)

            NC, N = x_test.shape[:2]
            ds = x_test.shape[2:]
            label = np.array(range(NC))[:,np.newaxis].repeat(N,axis=1)
            x_test = x_test.reshape(NC * N, *ds)
            l_test = label.reshape(NC * N, -1)

            return x_train, l_train, x_test, l_test, NC
+
if __name__ == '__main__':
    # Smoke test: build (or load from cache) the grouped Omniglot split
    # without an encoder; no features are computed in this mode.
    root = 'data/omniglot_data/'
    data = Omniglot(root)
+
+
diff --git a/text.py b/text.py
new file mode 100755
index 0000000..af92354
--- /dev/null
+++ b/text.py
@@ -0,0 +1,500 @@
+import sys
+import os
+import time
+import importlib
+import argparse
+
+import numpy as np
+
+import torch
+from torch import nn, optim
+
+from data import MonoTextData, VocabEntry
+from modules import VAE
+from modules import LSTMEncoder, LSTMDecoder, GaussianLSTMEncoder
+from logger import Logger
+from utils import calc_mi
+
+
+clip_grad = 5.0
+decay_epoch = 5
+lr_decay = 0.5
+max_decay = 5
+
+
+def init_config():
+ parser = argparse.ArgumentParser(description='VAE mode collapse study')
+ # model hyperparameters
+ parser.add_argument('--dataset', default='synthetic', type=str, help='dataset to use')
+ # optimization parameters
+ parser.add_argument('--momentum', type=float, default=0, help='sgd momentum')
+ parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
+ parser.add_argument('--iw_nsamples', type=int, default=500,
+ help='number of samples to compute importance weighted estimate')
+ # select mode
+ parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
+ parser.add_argument('--load_path', type=str, default='')
+ # annealing parameters
+ parser.add_argument('--warm_up', type=int, default=10, help="number of annealing epochs")
+ parser.add_argument('--kl_start', type=float, default=0.0, help="starting KL weight")
+ # these are for slurm purpose to save model
+ parser.add_argument('--jobid', type=int, default=0, help='slurm job id')
+ parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
+ parser.add_argument('--device', type=str, default="cpu")
+ parser.add_argument('--delta_rate', type=float, default=1,
+ help="control the minimization of the variation of latent variables")
+
+ # new
+ parser.add_argument("--target_kl", type=float, default=-1,
+ help="target kl of the free bits trick")
+ parser.add_argument('--gamma', type=float, default=1.0) # BN-VAE
+ parser.add_argument("--reset_dec", action="store_true", default=False)
+ parser.add_argument("--nz_new", type=int, default=32)
+ parser.add_argument('--p_drop', type=float, default=0.5) # p \in [0, 1]
+ parser.add_argument('--lr', type=float, default=1.0) # delta-VAE
+ args = parser.parse_args()
+
+ if 'cuda' in args.device:
+ args.cuda = True
+ else:
+ args.cuda = False
+
+ load_str = "_load" if args.load_path != "" else ""
+ load_str += '_eval' if args.eval else ''
+ save_dir = "models/%s%s/" % (args.dataset, load_str)
+
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ cw_str = '_warm%d' % args.warm_up
+ else:
+ cw_str = ''
+
+ hkl_str = 'KL%.2f' % args.kl_start
+ drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else ''
+ seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
+ args.seed = seed_set[args.taskid]
+
+ if args.gamma > 0:
+ gamma_str = '_gamma%.2f' % (args.gamma)
+ else:
+ gamma_str = ''
+
+ momentum_str = '_m%.2f' % args.momentum if args.momentum > 0 else ''
+ id_ = "%s_%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d%s_lr%.1f" % \
+ (args.dataset, hkl_str,
+ cw_str, load_str, gamma_str,args.delta_rate,
+ args.nz_new, drop_str,
+ args.jobid, args.taskid, args.seed, momentum_str, args.lr)
+
+ save_dir += id_
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ save_path = os.path.join(save_dir, 'model.pt')
+
+ args.save_path = save_path
+ print("save path", args.save_path)
+
+ args.log_path = os.path.join(save_dir, "log.txt")
+ print("log path", args.log_path)
+
+ # load config file into args
+ config_file = "config.config_%s" % args.dataset
+ params = importlib.import_module(config_file).params
+ args = argparse.Namespace(**vars(args), **params)
+ if args.nz != args.nz_new:
+ args.nz = args.nz_new
+
+ if 'label' in params:
+ args.label = params['label']
+ else:
+ args.label = False
+ if 'vocab_file' not in params:
+ args.vocab_file = None
+
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if args.cuda:
+ torch.cuda.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+
+ return args
+
+
+def interploation(model, data, strategy, device):
+ for i, (batch_data, sent_len) in enumerate(data.data_iter(batch_size=2, device=device,
+ batch_first=True, shuffle=False)):
+
+ z = model.sample_from_inference(batch_data).squeeze(1)
+ import ipdb
+ ipdb.set_trace()
+ z1 = z[[0], :]
+ z2 = z[[1], :]
+ ipdb.set_trace()
+ zs = []
+ for alpha in [0, 0.2, 0.4, 0.6, 0.8, 1]:
+ z = alpha * z2 + (1 - alpha) * z1
+ zs.append(z)
+ zs = torch.cat(zs, dim=0)
+ decoded_batch = model.decode(zs, strategy)
+ for sent in decoded_batch:
+ print(' '.join(sent) + '\n')
+ if i >= 10:
+ return
+
+
+def sample_from_prior(model, z, strategy, fname):
+ with open(fname, "w") as fout:
+ decoded_batch = model.decode(z, strategy)
+
+ for sent in decoded_batch:
+ fout.write(" ".join(sent) + "\n")
+
+
+def test(model, test_data_batch, mode, args, verbose=True):
+ report_kl_loss = report_kl_t_loss = report_rec_loss = 0
+ report_num_words = report_num_sents = 0
+ try:
+ print(model.encoder.theta)
+ theta = model.encoder.theta
+ print('+', torch.sum((theta > 0)))
+ print('-', torch.sum((theta < 0)))
+ except:
+ pass
+ for i in np.random.permutation(len(test_data_batch)):
+ batch_data = test_data_batch[i]
+ batch_size, sent_len = batch_data.size()
+
+ # not predict start symbol
+ report_num_words += (sent_len - 1) * batch_size
+
+ report_num_sents += batch_size
+
+ loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False)
+ loss_kl_t = model.KL(batch_data, args)
+ assert (not loss_rc.requires_grad)
+ assert (not loss_kl.requires_grad)
+
+ loss_rc = loss_rc.sum()
+ loss_kl = loss_kl.sum()
+ loss_kl_t = loss_kl_t.sum()
+
+ report_rec_loss += loss_rc.item()
+ report_kl_loss += loss_kl.item()
+ report_kl_t_loss += loss_kl_t.item()
+
+ mutual_info = calc_mi(model, test_data_batch, device=args.device)
+
+ test_loss = (report_rec_loss + report_kl_loss) / report_num_sents
+
+ nll = (report_kl_t_loss + report_rec_loss) / report_num_sents
+ kl = report_kl_loss / report_num_sents
+ kl_t = report_kl_t_loss / report_num_sents
+ ppl = np.exp(nll * report_num_sents / report_num_words)
+ if verbose:
+ print('%s --- avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \
+ (mode, test_loss, report_kl_t_loss / report_num_sents, mutual_info,
+ report_rec_loss / report_num_sents, nll, ppl))
+ sys.stdout.flush()
+
+ return test_loss, nll, kl_t, ppl, mutual_info  # return the true kl_t, not the KL used during training
+
+
+def calc_iwnll(model, test_data_batch, args, ns=100):
+ report_nll_loss = 0
+ report_num_words = report_num_sents = 0
+ for id_, i in enumerate(np.random.permutation(len(test_data_batch))):
+ batch_data = test_data_batch[i]
+ batch_size, sent_len = batch_data.size()
+
+ # not predict start symbol
+ report_num_words += (sent_len - 1) * batch_size
+
+ report_num_sents += batch_size
+ if id_ % (round(len(test_data_batch) / 10)) == 0:
+ print('iw nll computing %d0%%' % (id_ / (round(len(test_data_batch) / 10))))
+ sys.stdout.flush()
+
+ loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples, ns=ns)
+
+ report_nll_loss += loss.sum().item()
+
+ nll = report_nll_loss / report_num_sents
+ ppl = np.exp(nll * report_num_sents / report_num_words)
+
+ print('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
+ sys.stdout.flush()
+ return nll, ppl
+
+
+def calc_au(model, test_data_batch, delta=0.01):
+ """compute the number of active units
+ """
+ cnt = 0
+ for batch_data in test_data_batch:
+ mean, _ = model.encode_stats(batch_data)
+ if cnt == 0:
+ means_sum = mean.sum(dim=0, keepdim=True)
+ else:
+ means_sum = means_sum + mean.sum(dim=0, keepdim=True)
+ cnt += mean.size(0)
+
+ # (1, nz)
+ mean_mean = means_sum / cnt
+
+ cnt = 0
+ for batch_data in test_data_batch:
+ mean, _ = model.encode_stats(batch_data)
+ if cnt == 0:
+ var_sum = ((mean - mean_mean) ** 2).sum(dim=0)
+ else:
+ var_sum = var_sum + ((mean - mean_mean) ** 2).sum(dim=0)
+ cnt += mean.size(0)
+
+ # (nz)
+ au_var = var_sum / (cnt - 1)
+
+ return (au_var >= delta).sum().item(), au_var
+
+
+def main(args):
+ class uniform_initializer(object):
+ def __init__(self, stdv):
+ self.stdv = stdv
+
+ def __call__(self, tensor):
+ nn.init.uniform_(tensor, -self.stdv, self.stdv)
+
+ if args.cuda:
+ print('using cuda')
+
+ print(args)
+
+ opt_dict = {"not_improved": 0, "lr": args.lr, "best_loss": 1e4}
+ if args.vocab_file is not None:
+ print(args.vocab_file)
+ vocab = {}
+ with open(args.vocab_file) as fvocab:
+ for i, line in enumerate(fvocab):
+ vocab[line.strip()] = i
+
+ vocab = VocabEntry(vocab)
+ train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
+ else:
+ train_data = MonoTextData(args.train_data, label=args.label)
+
+ vocab = train_data.vocab
+ vocab_size = len(vocab)
+
+ val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
+ test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)
+
+ print('Train data: %d samples' % len(train_data))
+ print('finish reading datasets, vocab size is %d' % len(vocab))
+ print('dropped sentences: %d' % train_data.dropped)
+ sys.stdout.flush()
+
+ log_niter = (len(train_data) // args.batch_size) // 10
+
+ model_init = uniform_initializer(0.01)
+ emb_init = uniform_initializer(0.1)
+
+ args.device = torch.device(args.device)
+ device = args.device
+
+ if args.gamma > 0 and args.enc_type == 'lstm':
+ encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
+ args.enc_nh = args.dec_nh
+ elif args.gamma == 0 and args.enc_type == 'lstm':
+ encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
+ args.enc_nh = args.dec_nh
+ else:
+ raise ValueError("the specified encoder type is not supported")
+
+ decoder = LSTMDecoder(args, vocab, model_init, emb_init)
+
+ vae = VAE(encoder, decoder, args).to(device)
+
+ if args.load_path:
+ loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device))
+ vae.load_state_dict(loaded_state_dict, strict=False)
+ print("%s loaded" % args.load_path)
+
+ if args.reset_dec:
+ if args.gamma > 0:
+ vae.encoder.reset_parameters(model_init, emb_init, reset=True)
+ print("\n-------reset decoder-------\n")
+ vae.decoder.reset_parameters(model_init, emb_init)
+
+ if args.eval:
+ print('begin evaluation')
+ args.kl_weight = 1
+ vae.load_state_dict(torch.load(args.load_path, map_location=torch.device(device)))
+ vae.eval()
+ with torch.no_grad():
+ test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+ test(vae, test_data_batch, "TEST", args)
+ au, au_var = calc_au(vae, test_data_batch)
+ print("%d active units" % au)
+ # print(au_var)
+ test_data_batch = test_data.create_data_batch(batch_size=1,
+ device=device,
+ batch_first=True)
+ calc_iwnll(vae, test_data_batch, args)
+ return
+
+ enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr, momentum=args.momentum)
+ dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr, momentum=args.momentum)
+ opt_dict['lr'] = args.lr
+
+ iter_ = decay_cnt = 0
+ best_loss = 1e4
+ vae.train()
+ start = time.time()
+ kl_weight = args.kl_start
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ anneal_rate = (1.0 - args.kl_start) / (
+ args.warm_up * (len(train_data) / args.batch_size))  # when kl_start == 0, anneal_rate == 0
+ else:
+ anneal_rate = 0
+
+ train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+
+ val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+
+ test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+ for epoch in range(args.epochs):
+ report_kl_loss = report_rec_loss = 0
+ report_num_words = report_num_sents = 0
+ for i in np.random.permutation(len(train_data_batch)): # len(train_data_batch)
+ batch_data = train_data_batch[i]
+ batch_size, sent_len = batch_data.size()
+ if batch_data.size(0) < 16:
+ continue
+ # not predict start symbol
+ report_num_words += (sent_len - 1) * batch_size
+
+ report_num_sents += batch_size
+
+ # kl_weight = 1.0
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ kl_weight = min(1.0, kl_weight + anneal_rate)
+ else:
+ kl_weight = 1.0
+
+ args.kl_weight = kl_weight
+
+ enc_optimizer.zero_grad()
+ dec_optimizer.zero_grad()
+ loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args)
+ # loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples)
+
+ loss = loss.mean(dim=0)
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)
+
+ loss_rc = loss_rc.sum()
+ loss_kl = loss_kl.sum()
+
+ enc_optimizer.step()
+ dec_optimizer.step()
+
+ report_rec_loss += loss_rc.item()
+ report_kl_loss += loss_kl.item()
+
+ if iter_ % log_niter == 0:
+ train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
+ if epoch == 0:
+ vae.eval()
+ with torch.no_grad():
+ mi = calc_mi(vae, val_data_batch, device=device)
+ au, _ = calc_au(vae, val_data_batch)
+ vae.train()
+
+ print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f,' \
+ 'au %d, time elapsed %.2fs' %
+ (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
+ report_rec_loss / report_num_sents, au, time.time() - start))
+ else:
+ print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, recon: %.4f,' \
+ 'time elapsed %.2fs' %
+ (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
+ report_rec_loss / report_num_sents, time.time() - start))
+
+ sys.stdout.flush()
+
+ report_rec_loss = report_kl_loss = 0
+ report_num_words = report_num_sents = 0
+
+ iter_ += 1
+
+ print('kl weight %.4f' % args.kl_weight)
+
+ vae.eval()
+ with torch.no_grad():
+ loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
+ au, au_var = calc_au(vae, val_data_batch)
+ print("%d active units" % au)
+ # print(au_var)
+
+ if loss < best_loss:
+ print('update best loss')
+ best_loss = loss
+ torch.save(vae.state_dict(), args.save_path)
+ # torch.save(vae.state_dict(), args.save_path)
+ if loss > opt_dict["best_loss"]:
+ opt_dict["not_improved"] += 1
+ if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
+ opt_dict["best_loss"] = loss
+ opt_dict["not_improved"] = 0
+ opt_dict["lr"] = opt_dict["lr"] * lr_decay
+ vae.load_state_dict(torch.load(args.save_path))
+ print('new lr: %f' % opt_dict["lr"])
+ decay_cnt += 1
+ enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
+ dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
+
+ else:
+ opt_dict["not_improved"] = 0
+ opt_dict["best_loss"] = loss
+
+ if decay_cnt == max_decay:
+ break
+
+ if epoch % args.test_nepoch == 0:
+ with torch.no_grad():
+ loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
+
+ vae.train()
+
+ # compute importance weighted estimate of log p(x)
+ vae.load_state_dict(torch.load(args.save_path))
+
+ vae.eval()
+ with torch.no_grad():
+ loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
+ au, au_var = calc_au(vae, test_data_batch)
+ print("%d active units" % au)
+ # print(au_var)
+
+ test_data_batch = test_data.create_data_batch(batch_size=1,
+ device=device,
+ batch_first=True)
+ with torch.no_grad():
+ calc_iwnll(vae, test_data_batch, args)
+ return args
+
+
+if __name__ == '__main__':
+ args = init_config()
+ if not args.eval:
+ sys.stdout = Logger(args.log_path)
+ main(args)
diff --git a/text_IAF.py b/text_IAF.py
new file mode 100644
index 0000000..79a30f4
--- /dev/null
+++ b/text_IAF.py
@@ -0,0 +1,447 @@
+import sys
+import os
+import time
+import importlib
+import argparse
+import numpy as np
+import torch
+from torch import nn, optim
+from data import MonoTextData, VocabEntry
+from modules import VAEIAF as VAE
+from modules import VariationalFlow, LSTMDecoder
+from logger import Logger
+from utils import calc_mi, calc_au
+
+clip_grad = 5.0
+decay_epoch = 5
+lr_decay = 0.5
+max_decay = 5
+
+
+def init_config():
+ parser = argparse.ArgumentParser(description='VAE-IAF mode collapse study')
+
+ # model hyperparameters
+ parser.add_argument('--dataset', default='synthetic', type=str, help='dataset to use')
+
+ # optimization parameters
+ parser.add_argument('--momentum', type=float, default=0, help='sgd momentum')
+ parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
+ parser.add_argument('--iw_nsamples', type=int, default=500,
+ help='number of samples to compute importance weighted estimate')
+
+ # select mode
+ parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
+ parser.add_argument('--load_path', type=str, default='')
+
+ # annealing parameters
+ parser.add_argument('--warm_up', type=int, default=100, help="number of annealing epochs")
+ parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight")
+
+ # these are for slurm purpose to save model
+ parser.add_argument('--jobid', type=int, default=0, help='slurm job id')
+ parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
+ parser.add_argument('--device', type=str, default="cpu")
+ parser.add_argument('--delta_rate', type=float, default=1.0,
+ help="control the minimization of the variation of latent variables")
+
+ parser.add_argument('--gamma', type=float, default=0.8) # BN-VAE
+
+ parser.add_argument("--nz_new", type=int, default=32)
+
+ parser.add_argument('--p_drop', type=float, default=0.3) # p \in [0, 1]
+ parser.add_argument('--lr', type=float, default=1.0) # delta-VAE
+
+ parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow")
+ parser.add_argument('--flow_width', type=int, default=2, help="width of flow")
+
+ parser.add_argument("--fb", type=int, default=1,
+ help="0: no fb; 1: fb;")
+
+ parser.add_argument("--target_kl", type=float, default=0.0,
+ help="target kl of the free bits trick")
+
+ parser.add_argument('--drop_start', type=float, default=1.0, help="starting KL weight")
+
+ args = parser.parse_args()
+
+ if 'cuda' in args.device:
+ args.cuda = True
+ else:
+ args.cuda = False
+
+ load_str = "_load" if args.load_path != "" else ""
+ save_dir = "models/%s%s/" % (args.dataset, load_str)
+
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ cw_str = '_warm%d' % args.warm_up + '_%.2f' % args.kl_start
+ else:
+ cw_str = ''
+
+ if args.fb == 0:
+ fb_str = ""
+ elif args.fb in [1, 2]:
+ fb_str = "_fb%d_tr%.2f" % (args.fb, args.target_kl)
+
+ else:
+ fb_str = ''
+
+ drop_str = '_drop%.2f' % args.p_drop if args.p_drop != 0 else ''
+ if 1.0 > args.drop_start > 0 and args.p_drop != 0:
+ drop_str += '_start%.2f' % args.drop_start
+
+ seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
+ args.seed = seed_set[args.taskid]
+
+ if args.gamma > 0:
+ gamma_str = '_gamma%.2f' % (args.gamma)
+ else:
+ gamma_str = ''
+
+ if args.flow_depth > 0:
+ fd_str = '_fd%d_fw%d' % (args.flow_depth, args.flow_width)
+
+ momentum_str = '_m%.2f' % args.momentum if args.momentum > 0 else ''
+ id_ = "%s%s%s%s%s%s_dr%.2f_nz%d%s_%d_%d_%d%s_lr%.1f_IAF" % \
+ (args.dataset, cw_str, load_str, gamma_str, fb_str, fd_str,
+ args.delta_rate, args.nz_new, drop_str,
+ args.jobid, args.taskid, args.seed, momentum_str, args.lr)
+ save_dir += id_
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ save_path = os.path.join(save_dir, 'model.pt')
+
+ args.save_path = save_path
+ print("save path", args.save_path)
+
+ args.log_path = os.path.join(save_dir, "log.txt")
+ print("log path", args.log_path)
+
+ # load config file into args
+ config_file = "config.config_%s" % args.dataset
+ params = importlib.import_module(config_file).params
+ args = argparse.Namespace(**vars(args), **params)
+ if args.nz != args.nz_new:
+ args.nz = args.nz_new
+
+ if 'label' in params:
+ args.label = params['label']
+ else:
+ args.label = False
+ if 'vocab_file' not in params:
+ args.vocab_file = None
+
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if args.cuda:
+ torch.cuda.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+
+ return args
+
+
+def test(model, test_data_batch, mode, args, verbose=True):
+ report_kl_loss = report_kl_t_loss = report_rec_loss = 0
+ report_num_words = report_num_sents = 0
+
+ for i in np.random.permutation(len(test_data_batch)):
+ batch_data = test_data_batch[i]
+ batch_size, sent_len = batch_data.size()
+
+ # not predict start symbol
+ report_num_words += (sent_len - 1) * batch_size
+
+ report_num_sents += batch_size
+
+ loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, args, training=False)
+ loss_kl_t = model.KL(batch_data, args)
+ assert (not loss_rc.requires_grad)
+ assert (not loss_kl.requires_grad)
+
+ loss_rc = loss_rc.sum()
+ loss_kl = loss_kl.sum()
+ loss_kl_t = loss_kl_t.sum()
+
+ report_rec_loss += loss_rc.item()
+ report_kl_loss += loss_kl.item()
+ report_kl_t_loss += loss_kl_t.item()
+
+ mutual_info = calc_mi(model, test_data_batch, device=args.device)
+
+ test_loss = (report_rec_loss + report_kl_loss) / report_num_sents
+
+ nll = (report_kl_t_loss + report_rec_loss) / report_num_sents
+ kl = report_kl_loss / report_num_sents
+ kl_t = report_kl_t_loss / report_num_sents
+ ppl = np.exp(nll * report_num_sents / report_num_words)
+ if verbose:
+ print('%s --- avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \
+ (mode, test_loss, report_kl_t_loss / report_num_sents, mutual_info,
+ report_rec_loss / report_num_sents, nll, ppl))
+ sys.stdout.flush()
+
+ return test_loss, nll, kl_t, ppl, mutual_info  # return the true kl_t, not the KL used during training
+
+
+def calc_iwnll(model, test_data_batch, args, ns=100):
+ report_nll_loss = 0
+ report_num_words = report_num_sents = 0
+ for id_, i in enumerate(np.random.permutation(len(test_data_batch))):
+ batch_data = test_data_batch[i]
+ batch_size, sent_len = batch_data.size()
+
+ # not predict start symbol
+ report_num_words += (sent_len - 1) * batch_size
+
+ report_num_sents += batch_size
+ if id_ % (round(len(test_data_batch) / 10)) == 0:
+ print('iw nll computing %d0%%' % (id_ / (round(len(test_data_batch) / 10))))
+ sys.stdout.flush()
+
+ loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples, ns=ns)
+
+ report_nll_loss += loss.sum().item()
+
+ nll = report_nll_loss / report_num_sents
+ ppl = np.exp(nll * report_num_sents / report_num_words)
+
+ print('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
+ sys.stdout.flush()
+ return nll, ppl
+
+
+def main(args):
+ class uniform_initializer(object):
+ def __init__(self, stdv):
+ self.stdv = stdv
+
+ def __call__(self, tensor):
+ nn.init.uniform_(tensor, -self.stdv, self.stdv)
+
+ if args.cuda:
+ print('using cuda')
+
+ print(args)
+
+ opt_dict = {"not_improved": 0, "lr": args.lr, "best_loss": 1e4}
+
+ if args.vocab_file is not None:
+ print(args.vocab_file)
+ vocab = {}
+ with open(args.vocab_file) as fvocab:
+ for i, line in enumerate(fvocab):
+ vocab[line.strip()] = i
+ vocab = VocabEntry(vocab)
+ train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)
+ else:
+ train_data = MonoTextData(args.train_data, label=args.label)
+
+ vocab = train_data.vocab
+
+ vocab_size = len(vocab)
+
+ val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
+ test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)
+
+ print('Train data: %d samples' % len(train_data))
+ print('finish reading datasets, vocab size is %d' % len(vocab))
+ print('dropped sentences: %d' % train_data.dropped)
+ sys.stdout.flush()
+
+ log_niter = (len(train_data) // args.batch_size) // 10
+
+ model_init = uniform_initializer(0.01)
+ emb_init = uniform_initializer(0.1)
+
+ args.device = torch.device(args.device)
+ device = args.device
+
+ encoder = VariationalFlow(args, vocab_size, model_init, emb_init)
+ args.enc_nh = args.dec_nh
+
+ decoder = LSTMDecoder(args, vocab, model_init, emb_init)
+
+ vae = VAE(encoder, decoder, args).to(device)
+
+ if args.load_path:
+ loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device))
+ vae.load_state_dict(loaded_state_dict, strict=False)
+ print("%s loaded" % args.load_path)
+
+
+ if args.eval:
+ print('begin evaluation')
+ vae.load_state_dict(torch.load(args.load_path, map_location=torch.device(device)))
+ vae.eval()
+ with torch.no_grad():
+ test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+ test(vae, test_data_batch, "TEST", args)
+ au, au_var = calc_au(vae, test_data_batch)
+ print("%d active units" % au)
+ # print(au_var)
+ test_data_batch = test_data.create_data_batch(batch_size=1,
+ device=device,
+ batch_first=True)
+ calc_iwnll(vae, test_data_batch, args)
+ return
+
+ enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr, momentum=args.momentum)
+ dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr, momentum=args.momentum)
+ opt_dict['lr'] = args.lr
+
+ iter_ = decay_cnt = 0
+ best_loss = 1e4
+ vae.train()
+ start = time.time()
+
+ kl_weight = args.kl_start
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ anneal_rate = (1.0 - args.kl_start) / (
+ args.warm_up * (len(train_data) / args.batch_size))  # when kl_start == 0, anneal_rate == 0
+ else:
+ anneal_rate = 0
+
+ train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+
+ val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+
+ test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
+ device=device,
+ batch_first=True)
+ for epoch in range(args.epochs):
+ report_kl_loss = report_rec_loss = 0
+ report_num_words = report_num_sents = 0
+ for i in np.random.permutation(len(train_data_batch)): # len(train_data_batch)
+ batch_data = train_data_batch[i]
+ batch_size, sent_len = batch_data.size()
+ if batch_data.size(0) < 16:
+ continue
+
+ # not predict start symbol
+ report_num_words += (sent_len - 1) * batch_size
+
+ report_num_sents += batch_size
+
+ # kl_weight = 1.0
+ if args.warm_up > 0 and args.kl_start < 1.0:
+ kl_weight = min(1.0, kl_weight + anneal_rate)
+ else:
+ kl_weight = 1.0
+
+ args.kl_weight = kl_weight
+
+ enc_optimizer.zero_grad()
+ dec_optimizer.zero_grad()
+ if args.fb == 0 or args.fb == 1:
+ loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, args)
+
+ loss = loss.mean(dim=0)
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)
+
+ loss_rc = loss_rc.sum()
+ loss_kl = loss_kl.sum()
+
+ enc_optimizer.step()
+ dec_optimizer.step()
+
+ report_rec_loss += loss_rc.item()
+ report_kl_loss += loss_kl.item()
+
+ if iter_ % log_niter == 0:
+ train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
+ if epoch == 0:
+ vae.eval()
+ with torch.no_grad():
+ mi = calc_mi(vae, val_data_batch, device=device)
+ au, _ = calc_au(vae, val_data_batch)
+ vae.train()
+
+ print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, mi: %.4f, recon: %.4f,' \
+ 'au %d, time elapsed %.2fs' %
+ (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
+ report_rec_loss / report_num_sents, au, time.time() - start))
+ else:
+ print('epoch: %d, iter: %d, avg_loss: %.4f, kl/H(z|x): %.4f, recon: %.4f,' \
+ 'time elapsed %.2fs' %
+ (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
+ report_rec_loss / report_num_sents, time.time() - start))
+
+ sys.stdout.flush()
+
+ report_rec_loss = report_kl_loss = 0
+ report_num_words = report_num_sents = 0
+
+ iter_ += 1
+
+ print('kl weight %.4f' % args.kl_weight)
+
+ vae.eval()
+ with torch.no_grad():
+ loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
+ au, au_var = calc_au(vae, val_data_batch)
+ print("%d active units" % au)
+ # print(au_var)
+
+ if loss < best_loss:
+ print('update best loss')
+ best_loss = loss
+ torch.save(vae.state_dict(), args.save_path)
+ # torch.save(vae.state_dict(), args.save_path)
+ if loss > opt_dict["best_loss"]:
+ opt_dict["not_improved"] += 1
+ if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
+ opt_dict["best_loss"] = loss
+ opt_dict["not_improved"] = 0
+ opt_dict["lr"] = opt_dict["lr"] * lr_decay
+ vae.load_state_dict(torch.load(args.save_path))
+ print('new lr: %f' % opt_dict["lr"])
+ decay_cnt += 1
+ enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
+ dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
+
+ else:
+ opt_dict["not_improved"] = 0
+ opt_dict["best_loss"] = loss
+
+ if decay_cnt == max_decay:
+ break
+
+ if epoch % args.test_nepoch == 0:
+ with torch.no_grad():
+ loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
+
+ vae.train()
+
+ # compute importance weighted estimate of log p(x)
+ vae.load_state_dict(torch.load(args.save_path))
+
+ vae.eval()
+ with torch.no_grad():
+ loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
+ au, au_var = calc_au(vae, test_data_batch)
+ print("%d active units" % au)
+ # print(au_var)
+
+ test_data_batch = test_data.create_data_batch(batch_size=1,
+ device=device,
+ batch_first=True)
+ with torch.no_grad():
+ calc_iwnll(vae, test_data_batch, args)
+ return args
+
+
+if __name__ == '__main__':
+ args = init_config()
+ if not args.eval:
+ sys.stdout = Logger(args.log_path)
+ main(args)
diff --git a/text_ss.py b/text_ss.py
new file mode 100644
index 0000000..06505aa
--- /dev/null
+++ b/text_ss.py
@@ -0,0 +1,423 @@
+import os
+import time
+import importlib
+import argparse
+
+import numpy as np
+
+import torch
+from torch import nn, optim
+# import swats
+
+from collections import defaultdict
+
+from data import MonoTextData, VocabEntry
+from modules import VAE,LinearDiscriminator_only
+from modules import GaussianLSTMEncoder, LSTMEncoder, LSTMDecoder, VariationalFlow
+
+from exp_utils import create_exp_dir
+from utils import uniform_initializer
+
# Optimization-schedule constants ("old parameters"): gradient-clipping norm,
# number of non-improving validation epochs before an LR decay, multiplicative
# LR decay factor, and number of decays before early stopping.
clip_grad = 5.0
decay_epoch = 2
lr_decay = 0.5
max_decay = 5

# Junxian's new parameters (alternative schedule, kept for reference)
# clip_grad = 1.0
# decay_epoch = 5
# lr_decay = 0.8
# max_decay = 10

# Module-level logging callable; assigned in main() via create_exp_dir().
logging = None
+
+
def init_config():
    """Parse CLI arguments, seed all RNGs, and derive experiment paths.

    Merges the parsed options with the per-dataset semi-supervised parameter
    dict ``params_ss_<num_label>`` loaded from ``config.config_<dataset>``.

    Returns:
        argparse.Namespace with all CLI options, config params, and the
        derived fields ``cuda``, ``seed``, ``exp_dir``, ``save_path``,
        ``label`` and ``kl_weight``.

    Raises:
        ValueError: if the dataset config defines no params for num_label.
    """
    parser = argparse.ArgumentParser(description='VAE mode collapse study')
    parser.add_argument('--gamma', type=float, default=0.0)
    # model hyperparameters
    parser.add_argument('--dataset', default='yelp', type=str, help='dataset to use')
    # optimization parameters
    parser.add_argument('--momentum', type=float, default=0, help='sgd momentum')
    parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="sgd", help='sgd momentum')

    parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training')
    parser.add_argument('--iw_nsamples', type=int, default=500,
                        help='number of samples to compute importance weighted estimate')

    # select mode
    parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll')
    parser.add_argument('--load_path', type=str,
                        default='short_yelp_aggressive0_hs1.00_warm100_0_0_783435.pt')  # TODO: set load_path

    # annealing parameters
    parser.add_argument('--warm_up', type=int, default=100,
                        help="number of annealing epochs. warm_up=0 means not anneal")
    parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight")

    # output directory
    parser.add_argument('--exp_dir', default=None, type=str,
                        help='experiment directory.')
    parser.add_argument("--save_ckpt", type=int, default=0,
                        help="save checkpoint every epoch before this number")
    parser.add_argument("--save_latent", type=int, default=0)

    # new
    parser.add_argument("--reset_dec", action="store_true", default=True)
    parser.add_argument("--load_best_epoch", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1.)

    parser.add_argument("--batch_size", type=int, default=16,
                        help="target kl of the free bits trick")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of epochs")
    parser.add_argument("--update_every", type=int, default=1,
                        help="target kl of the free bits trick")
    parser.add_argument("--num_label", type=int, default=100,
                        help="target kl of the free bits trick")
    parser.add_argument("--freeze_enc", action="store_true",
                        default=True)  # True-> freeze the parameters of vae.encoder
    parser.add_argument("--discriminator", type=str, default="linear")

    parser.add_argument('--taskid', type=int, default=0, help='slurm task id')
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--delta_rate', type=float, default=0.0,
                        help=" coontrol the minization of the variation of latent variables")

    parser.add_argument('--p_drop', type=float, default=0)
    parser.add_argument('--IAF', action='store_true', default=False)
    parser.add_argument('--flow_depth', type=int, default=2, help="depth of flow")
    parser.add_argument('--flow_width', type=int, default=60, help="width of flow")

    args = parser.parse_args()

    # set args.cuda from the requested device string
    args.cuda = 'cuda' in args.device

    # set seeds: taskid indexes a fixed seed table for reproducible slurm arrays
    seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909]
    args.seed = seed_set[args.taskid]
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

    # load per-dataset config into args; fail loudly for an unsupported
    # num_label (the original if/elif chain left `params` undefined, which
    # surfaced later as a confusing NameError)
    config_file = "config.config_%s" % args.dataset
    config_module = importlib.import_module(config_file)
    try:
        params = getattr(config_module, "params_ss_%d" % args.num_label)
    except AttributeError:
        raise ValueError("unsupported num_label: %d" % args.num_label)

    args = argparse.Namespace(**vars(args), **params)

    load_str = "_load" if args.load_path != "" else ""

    opt_str = "_adam" if args.opt == "adam" else "_sgd"
    nlabel_str = "_nlabel{}".format(args.num_label)
    freeze_str = "_freeze" if args.freeze_enc else ""

    # pick a short name from the checkpoint path for the experiment dir.
    # Preserves the historical choice (third component when the path is deep),
    # but no longer raises IndexError for a bare filename such as the default.
    path_parts = args.load_path.split("/")
    load_path_str = path_parts[2] if len(path_parts) > 2 else path_parts[-1]

    model_str = "_{}".format(args.discriminator)
    # set load and save paths
    if args.exp_dir is None:
        args.exp_dir = "models/exp_{}{}_ss_ft/{}{}{}{}{}".format(args.dataset,
                                                                 load_str, load_path_str, model_str, opt_str,
                                                                 nlabel_str, freeze_str)

    if len(args.load_path) <= 0 and args.eval:
        args.load_path = os.path.join(args.exp_dir, 'model.pt')

    args.save_path = os.path.join(args.exp_dir, 'model.pt')

    # set args.label: whether the dataset files carry class labels
    args.label = params['label'] if 'label' in params else False

    # the classifier fine-tuning stage keeps a fixed KL weight
    args.kl_weight = 1

    return args
+
+
+def test(model, test_data_batch, test_labels_batch, mode, args, verbose=True):
+ global logging
+
+ report_correct = report_loss = 0
+ report_num_sents = 0
+ for i in np.random.permutation(len(test_data_batch)):
+ batch_data = test_data_batch[i]
+ batch_labels = test_labels_batch[i]
+ batch_labels = [int(x) for x in batch_labels]
+
+ batch_labels = torch.tensor(batch_labels, dtype=torch.long, requires_grad=False, device=args.device)
+
+ batch_size = batch_data.size(0)
+
+ # not predict start symbol
+ report_num_sents += batch_size
+
+ loss, correct = model.get_performance_with_feature(batch_data, batch_labels)
+
+ loss = loss.sum()
+
+ report_loss += loss.item()
+ report_correct += correct
+
+ test_loss = report_loss / report_num_sents
+ acc = report_correct / report_num_sents
+
+ if verbose:
+ logging('%s --- avg_loss: %.4f, acc: %.4f' % \
+ (mode, test_loss, acc))
+ # sys.stdout.flush()
+
+ return test_loss, acc
+
def main(args):
    """Train a discriminator on latent features of a pre-trained VAE encoder.

    Pipeline: build the vocabulary and datasets, restore the VAE from
    args.load_path, pre-encode every batch once into latent features (the
    encoder stays frozen), then train the discriminator with gradient
    accumulation, decaying the LR on validation-loss plateaus and stopping
    after `max_decay` decays. The best checkpoint (by validation loss) is
    written to args.save_path and re-evaluated on the test split.
    """
    global logging
    # create_exp_dir creates args.exp_dir and returns the logging function
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # LR-decay bookkeeping: plateau counter, current lr, best validation loss
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # vocabulary file: one token per line, mapped to its line index
    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i

    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    # sys.stdout.flush()

    # log roughly ten times per epoch
    log_niter = max(1, (len(train_data) // (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    # device = torch.device("cuda" if args.cuda else "cpu")
    # device = "cuda" if args.cuda else "cpu"
    device = args.device

    # choose the encoder variant matching the checkpoint being restored
    if args.gamma > 0 and args.enc_type == 'lstm' and not args.IAF:
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    elif args.gamma == 0 and args.enc_type == 'lstm' and not args.IAF:
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    elif args.IAF:
        encoder = VariationalFlow(args,vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(device)
    vae.eval()

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path, map_location=torch.device(device))
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # best-effort: some encoder variants expose a parameter `theta`
    try:
        print('theta', vae.encoder.theta)
    except:
        pass
    # NOTE(review): `discriminator` is only bound for the "linear" choice;
    # any other --discriminator value raises NameError below — confirm intended.
    if args.discriminator == "linear":
        discriminator = LinearDiscriminator_only(args, args.ncluster).to(device)
    # elif args.discriminator == "mlp":
    #     discriminator = MLPDiscriminator(args, vae.encoder).to(device)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(), lr=args.lr, momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        # optimizer = swats.SWATS(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # best_kl = best_nll = best_ppl = 0
    # pre_mi = 0
    discriminator.train()
    start = time.time()

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(batch_size=args.batch_size,
                                                                               device=device,
                                                                               batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(batch_size=128,
                                                                         device=device,
                                                                         batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(batch_size=128,
                                                                            device=device,
                                                                            batch_first=True)

    # Encode each batch once with the frozen VAE encoder; training below then
    # operates purely on these cached latent features (labels stay aligned
    # because features and labels are appended inside the same iteration).
    def learn_feature(data_batch,labels_batch):
        feature = []
        label = []
        for i in np.random.permutation(len(data_batch)):
            batch_data = data_batch[i]
            batch_labels = labels_batch[i]
            batch_data = batch_data.to(device)
            batch_size = batch_data.size(0)
            if args.IAF:
                # flow encoder: use the transformed sample zT as the feature
                loc, zT = vae.encoder.learn_feature(batch_data)
                # mu = torch.cat([loc, zT], dim=-1)
                mu=zT
                mu = mu.squeeze(1)
            else:
                mu, logvar = vae.encoder(batch_data)
            feature.append(mu.detach())
            label.append(batch_labels)
        return feature,label

    train_data_batch, train_labels_batch = learn_feature(train_data_batch, train_labels_batch)
    val_data_batch, val_labels_batch = learn_feature(val_data_batch, val_labels_batch)
    test_data_batch,test_labels_batch = learn_feature(test_data_batch,test_labels_batch)

    # gradient accumulation: step the optimizer every `update_every` batches
    acc_cnt = 1
    acc_loss = 0.
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            # skip degenerate single-example batches
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels, dtype=torch.long, requires_grad=False, device=device)

            batch_size = batch_data.size(0)

            # not predict start symbol
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance_with_feature(batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            # NOTE(review): train_loss is computed here but never logged/used
            if iter_ % log_niter == 0:
                train_loss = report_loss / report_num_sents

            iter_ += 1

        # logging('lr {}'.format(opt_dict["lr"]))
        # print(report_num_sents)
        discriminator.eval()

        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch, "VAL", args)
            # print(au_var)

        # checkpoint on any improvement of the validation loss
        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        # plateau handling: after `decay_epoch` non-improving epochs, decay the
        # lr, reload the best checkpoint, and rebuild the optimizer
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(), lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)

        discriminator.train()

    # final evaluation: reload the best checkpoint and score the test split
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()

    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)
        # print(au_var)
+
+
if __name__ == '__main__':
    # Entry point: parse CLI/config, then run discriminator fine-tuning.
    args = init_config()
    main(args)
diff --git a/utils.py b/utils.py
new file mode 100755
index 0000000..2735134
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,340 @@
+import numpy as np
+import os, sys
+import torch
+from torch import nn, optim
+import subprocess
+from modules import VariationalFlow, FlowResNetEncoderV2
+# from image_modules import GaussianEncoder
class uniform_initializer(object):
    """Callable that fills a tensor in place with values from U(-stdv, stdv)."""

    def __init__(self, stdv):
        # half-width of the uniform sampling range
        self.stdv = stdv

    def __call__(self, tensor):
        bound = self.stdv
        nn.init.uniform_(tensor, -bound, bound)
+
+
class xavier_normal_initializer(object):
    """Callable applying Xavier (Glorot) normal initialization in place."""

    def __call__(self, tensor):
        # delegate to the standard torch initializer
        nn.init.xavier_normal_(tensor)
+
+
+
def calc_iwnll(model, test_data_batch, args, ns=100):
    """Importance-weighted estimate of NLL and perplexity over a test set.

    Args:
        model: VAE exposing ``nll_iw(batch, nsamples, ns)`` returning a
            per-sentence negative log-likelihood tensor.
        test_data_batch: list of (batch_size, sent_len) index tensors.
        args: namespace providing ``iw_nsamples`` (importance samples).
        ns: chunk size of samples evaluated per forward pass inside nll_iw.

    Returns:
        (nll, ppl): average per-sentence NLL and per-word perplexity.
    """
    report_nll_loss = 0
    report_num_words = report_num_sents = 0
    # progress printed in ~5% steps; clamp to 1 so that small test sets
    # (fewer than 10 batches) no longer trigger a modulo-by-zero crash
    progress_step = max(1, round(len(test_data_batch) / 20))
    print("iw nll computing ", end="")
    for id_, i in enumerate(np.random.permutation(len(test_data_batch))):
        batch_data = test_data_batch[i]
        batch_size, sent_len = batch_data.size()

        # the start symbol is not predicted, hence sent_len - 1 words
        report_num_words += (sent_len - 1) * batch_size
        report_num_sents += batch_size

        if id_ % progress_step == 0:
            print('%d%% ' % (id_ / progress_step * 5), end="")
            sys.stdout.flush()

        loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples, ns=ns)
        report_nll_loss += loss.sum().item()

    print()
    sys.stdout.flush()

    nll = report_nll_loss / report_num_sents
    # ppl exponentiates the per-word NLL
    ppl = np.exp(nll * report_num_sents / report_num_words)

    return nll, ppl
+
+
def calc_mi(model, test_data_batch, device='cpu'):
    """Estimate the mutual information I(x; z) under the variational posterior.

    MI = E_x[-H(q(z|x))] - E_z[log q(z)], where the aggregate posterior q(z)
    is approximated by averaging every example's Gaussian density over the
    whole test set (hence the second, quadratic-cost pass below).

    Args:
        model: VAE whose encoder yields diagonal-Gaussian stats (mu, logvar).
        test_data_batch: iterable of batches (or [data, labels] lists).
        device: device used for the density computation.

    Returns:
        The MI estimate in nats (scalar tensor).
    """
    # calc_mi_v3
    import math
    from modules.utils import log_sum_exp

    mi = 0
    num_examples = 0

    # pass 1: collect per-batch posterior stats and accumulate the negative
    # entropy E_{q(z|x)} log q(z|x) of the diagonal Gaussians
    mu_batch_list, logvar_batch_list = [], []
    neg_entropy = 0.
    for batch_data in test_data_batch:

        # labelled loaders may yield [data, labels]; keep only the data
        if isinstance(batch_data, list):
            batch_data = batch_data[0]

        batch_data = batch_data.to(device)

        # flow-based encoders expose their base Gaussian stats via encode_stats
        if isinstance(model.encoder, VariationalFlow) \
                or isinstance(model.encoder, FlowResNetEncoderV2):
            # or isinstance(model.encoder, GaussianEncoder):
            mu, logvar = model.encode_stats(batch_data)
        else:
            mu, logvar = model.encoder.forward(batch_data)
        x_batch, nz = mu.size()
        ##print(x_batch, end=' ')
        num_examples += x_batch

        # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
        neg_entropy += (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).sum().item()
        mu_batch_list += [mu.cpu()]
        logvar_batch_list += [logvar.cpu()]

    neg_entropy = neg_entropy / num_examples

    # pass 2: sample one z per example and score it against the aggregate
    # posterior formed by ALL examples' Gaussians
    num_examples = 0
    log_qz = 0.
    for i in range(len(mu_batch_list)):
        ###############
        # get z_samples
        ###############
        mu, logvar = mu_batch_list[i].to(device), logvar_batch_list[i].to(device)

        # [z_batch, 1, nz]
        if hasattr(model.encoder, 'reparameterize'):
            z_samples = model.encoder.reparameterize(mu, logvar, 1)
        else:
            z_samples = model.encoder.gaussian_enc.reparameterize(mu, logvar, 1)
        z_samples = z_samples.view(-1, 1, nz)
        num_examples += z_samples.size(0)

        ###############
        # compute density
        ###############
        # [1, x_batch, nz]

        indices = np.arange(len(mu_batch_list))
        mu = torch.cat([mu_batch_list[_] for _ in indices], dim=0).to(device)
        logvar = torch.cat([logvar_batch_list[_] for _ in indices], dim=0).to(device)
        x_batch, nz = mu.size()

        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (z_batch, x_batch, nz)
        dev = z_samples - mu

        # (z_batch, x_batch) log N(z; mu_j, var_j) for every example j
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
                      0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): aggregate posterior
        # [z_batch]
        log_qz += (log_sum_exp(log_density, dim=1) - math.log(x_batch)).sum(-1)

    log_qz /= num_examples
    mi = neg_entropy - log_qz

    return mi
+
+
def calc_au(model, test_data_batch, delta=0.01):
    """Count active latent units.

    A unit is "active" when the variance of its posterior mean across the
    data exceeds `delta` (Burda et al. style AU metric).

    Returns:
        (n_active, au_var): the active-unit count and the per-dimension
        variance vector of shape (nz,).
    """
    # pass 1: average of the posterior means over all examples
    means_sum = None
    n = 0
    for batch_data in test_data_batch:
        mean, _ = model.encode_stats(batch_data)
        batch_sum = mean.sum(dim=0, keepdim=True)
        means_sum = batch_sum if means_sum is None else means_sum + batch_sum
        n += mean.size(0)

    # (1, nz)
    mean_mean = means_sum / n

    # pass 2: unbiased variance of the posterior means around that average
    var_sum = None
    n = 0
    for batch_data in test_data_batch:
        mean, _ = model.encode_stats(batch_data)
        sq_dev = ((mean - mean_mean) ** 2).sum(dim=0)
        var_sum = sq_dev if var_sum is None else var_sum + sq_dev
        n += mean.size(0)

    # (nz)
    au_var = var_sum / (n - 1)

    return (au_var >= delta).sum().item(), au_var
+
+
def sample_sentences(vae, vocab, device, num_sentences):
    """Sample `num_sentences` sentences from the VAE prior and log them.

    NOTE(review): this module never assigns `logging`; the `global logging`
    below only resolves if a caller injects it — verify before use.
    """
    global logging

    vae.eval()
    sampled_sents = []
    for i in range(num_sentences):
        # draw z ~ p(z) and reshape to (1, 1, nz) for the decoder
        z = vae.sample_from_prior(1)
        z = z.view(1, 1, -1)
        # NOTE(review): the empty-string keys below look like stripped
        # start/end-of-sentence markers (e.g. '<s>'/'</s>') lost in an
        # HTML/diff conversion — confirm against the vocabulary file.
        start = vocab.word2id['']
        # START = torch.tensor([[[start]]])
        START = torch.tensor([[start]])
        end = vocab.word2id['']
        START = START.to(device)
        z = z.to(device)
        vae.eval()
        # autoregressive decoding until the end symbol
        sentence = vae.decoder.sample_text(START, z, end, device)
        decoded_sentence = vocab.decode_sentence(sentence)
        sampled_sents.append(decoded_sentence)
    for i, sent in enumerate(sampled_sents):
        logging(i, ":", ' '.join(sent))
+
+
def visualize_latent(args, epoch, vae, device, test_data):
    """Dump posterior latent samples for the test set to a per-epoch file.

    Each output line is ``<label>\\t<space-separated latent values>``; the
    file is written to ``<exp_dir>/synthetic_latent_<epoch>.txt``.
    """
    nsamples = 1

    out_path = os.path.join(args.exp_dir, f'synthetic_latent_{epoch}.txt')
    with open(out_path, 'w') as f:
        data_batches, label_batches = test_data.create_data_batch_labels(batch_size=args.batch_size,
                                                                         device=device, batch_first=True)
        for batch_data, batch_label in zip(data_batches, label_batches):
            batch_size, sent_len = batch_data.size()
            # (batch_size, nsamples, nz) samples from the approximate posterior
            samples, _ = vae.encoder.encode(batch_data, nsamples)
            for row in range(batch_size):
                for col in range(nsamples):
                    sample = samples[row, col, :].cpu().detach().numpy().tolist()
                    f.write(batch_label[row] + '\t' + ' '.join([str(val) for val in sample]) + '\n')
+
+
# CSS/X11 color-name -> hex-code mapping (matplotlib's named colors).
# NOTE(review): not referenced anywhere in this file — presumably consumed by
# external plotting scripts; confirm before removing.
cnames = {
'aliceblue': '#F0F8FF',
'antiquewhite': '#FAEBD7',
'aqua': '#00FFFF',
'aquamarine': '#7FFFD4',
'azure': '#F0FFFF',
'beige': '#F5F5DC',
'bisque': '#FFE4C4',
'black': '#000000',
'blanchedalmond': '#FFEBCD',
'blue': '#0000FF',
'blueviolet': '#8A2BE2',
'brown': '#A52A2A',
'burlywood': '#DEB887',
'cadetblue': '#5F9EA0',
'chartreuse': '#7FFF00',
'chocolate': '#D2691E',
'coral': '#FF7F50',
'cornflowerblue': '#6495ED',
'cornsilk': '#FFF8DC',
'crimson': '#DC143C',
'cyan': '#00FFFF',
'darkblue': '#00008B',
'darkcyan': '#008B8B',
'darkgoldenrod': '#B8860B',
'darkgray': '#A9A9A9',
'darkgreen': '#006400',
'darkkhaki': '#BDB76B',
'darkmagenta': '#8B008B',
'darkolivegreen': '#556B2F',
'darkorange': '#FF8C00',
'darkorchid': '#9932CC',
'darkred': '#8B0000',
'darksalmon': '#E9967A',
'darkseagreen': '#8FBC8F',
'darkslateblue': '#483D8B',
'darkslategray': '#2F4F4F',
'darkturquoise': '#00CED1',
'darkviolet': '#9400D3',
'deeppink': '#FF1493',
'deepskyblue': '#00BFFF',
'dimgray': '#696969',
'dodgerblue': '#1E90FF',
'firebrick': '#B22222',
'floralwhite': '#FFFAF0',
'forestgreen': '#228B22',
'fuchsia': '#FF00FF',
'gainsboro': '#DCDCDC',
'ghostwhite': '#F8F8FF',
'gold': '#FFD700',
'goldenrod': '#DAA520',
'gray': '#808080',
'green': '#008000',
'greenyellow': '#ADFF2F',
'honeydew': '#F0FFF0',
'hotpink': '#FF69B4',
'indianred': '#CD5C5C',
'indigo': '#4B0082',
'ivory': '#FFFFF0',
'khaki': '#F0E68C',
'lavender': '#E6E6FA',
'lavenderblush': '#FFF0F5',
'lawngreen': '#7CFC00',
'lemonchiffon': '#FFFACD',
'lightblue': '#ADD8E6',
'lightcoral': '#F08080',
'lightcyan': '#E0FFFF',
'lightgoldenrodyellow': '#FAFAD2',
'lightgreen': '#90EE90',
'lightgray': '#D3D3D3',
'lightpink': '#FFB6C1',
'lightsalmon': '#FFA07A',
'lightseagreen': '#20B2AA',
'lightskyblue': '#87CEFA',
'lightslategray': '#778899',
'lightsteelblue': '#B0C4DE',
'lightyellow': '#FFFFE0',
'lime': '#00FF00',
'limegreen': '#32CD32',
'linen': '#FAF0E6',
'magenta': '#FF00FF',
'maroon': '#800000',
'mediumaquamarine': '#66CDAA',
'mediumblue': '#0000CD',
'mediumorchid': '#BA55D3',
'mediumpurple': '#9370DB',
'mediumseagreen': '#3CB371',
'mediumslateblue': '#7B68EE',
'mediumspringgreen': '#00FA9A',
'mediumturquoise': '#48D1CC',
'mediumvioletred': '#C71585',
'midnightblue': '#191970',
'mintcream': '#F5FFFA',
'mistyrose': '#FFE4E1',
'moccasin': '#FFE4B5',
'navajowhite': '#FFDEAD',
'navy': '#000080',
'oldlace': '#FDF5E6',
'olive': '#808000',
'olivedrab': '#6B8E23',
'orange': '#FFA500',
'orangered': '#FF4500',
'orchid': '#DA70D6',
'palegoldenrod': '#EEE8AA',
'palegreen': '#98FB98',
'paleturquoise': '#AFEEEE',
'palevioletred': '#DB7093',
'papayawhip': '#FFEFD5',
'peachpuff': '#FFDAB9',
'peru': '#CD853F',
'pink': '#FFC0CB',
'plum': '#DDA0DD',
'powderblue': '#B0E0E6',
'purple': '#800080',
'red': '#FF0000',
'rosybrown': '#BC8F8F',
'royalblue': '#4169E1',
'saddlebrown': '#8B4513',
'salmon': '#FA8072',
'sandybrown': '#FAA460',
'seagreen': '#2E8B57',
'seashell': '#FFF5EE',
'sienna': '#A0522D',
'silver': '#C0C0C0',
'skyblue': '#87CEEB',
'slateblue': '#6A5ACD',
'slategray': '#708090',
'snow': '#FFFAFA',
'springgreen': '#00FF7F',
'steelblue': '#4682B4',
'tan': '#D2B48C',
'teal': '#008080',
'thistle': '#D8BFD8',
'tomato': '#FF6347',
'turquoise': '#40E0D0',
'violet': '#EE82EE',
'wheat': '#F5DEB3',
'white': '#FFFFFF',
'whitesmoke': '#F5F5F5',
'yellow': '#FFFF00',
'yellowgreen': '#9ACD32'}
\ No newline at end of file