#! /usr/bin/env python

import tensorflow as tf
import numpy as np
import os
import time
import datetime
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from builddata_ecir import *
from capsuleNet_SEARCH17 import CapsE

np.random.seed(1234)
tf.set_random_seed(1234)

# Parameters
# ==================================================
parser = ArgumentParser("CapsE", formatter_class=ArgumentDefaultsHelpFormatter, conflict_handler='resolve')

parser.add_argument("--data", default="./data/", help="Data sources.")
parser.add_argument("--run_folder", default="./", help="Data sources.")
parser.add_argument("--name", default="SEARCH17", help="Name of the dataset.")

parser.add_argument("--embedding_dim", default=200, type=int, help="Dimensionality of character embedding (fixed: 200)")
parser.add_argument("--filter_size", default=1, type=int, help="Comma-separated filter sizes (default: '3,4,5')")
parser.add_argument("--num_filters", default=400, type=int, help="Number of filters per filter size (default: 128)")
parser.add_argument("--learning_rate", default=0.00001, type=float, help="Learning rate")
parser.add_argument("--batch_size", default=128, type=int, help="Batch Size")
parser.add_argument("--neg_ratio", default=1.0, help="Number of negative triples generated by positive (default: 1.0)")
parser.add_argument("--useInitialization", default=True, type=bool, help="Using the pretrained embeddings")
parser.add_argument("--num_epochs", default=100, type=int, help="Number of training epochs")
parser.add_argument("--savedEpochs", default=10, type=int, help="")
parser.add_argument("--allow_soft_placement", default=True, type=bool, help="Allow device soft device placement")
parser.add_argument("--log_device_placement", default=False, type=bool, help="Log placement of ops on devices")
parser.add_argument("--model_name", default='search17model', help="")
parser.add_argument("--useConstantInit", action='store_true')

parser.add_argument('--iter_routing', default=1, type=int, help='number of iterations in routing algorithm')
parser.add_argument('--num_outputs_secondCaps', default=1, type=int, help='Number of capsules in the second capsule layer')
parser.add_argument('--vec_len_secondCaps', default=10, type=int, help='Vector length of each capsule in the second capsule layer')

args = parser.parse_args()
print(args)
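
# Example invocation (a minimal sketch; the script filename is an assumption, adjust to your setup):
#   python trainCapsE_SEARCH17.py --name SEARCH17 --batch_size 128 --learning_rate 0.00001 --num_epochs 100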
# Load data
print("Loading data...")

train_triples, train_rank_triples, train_val_triples, valid_triples, valid_rank_triples, valid_val_triples, \
            test_triples, test_rank_triples, test_val_triples, query_indexes, user_indexes, doc_indexes, \
            indexes_query, indexes_user, indexes_doc = build_data_ecir()
data_size = len(train_triples)
train_batch = Batch_Loader_ecir(train_triples, train_val_triples, batch_size=args.batch_size)
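# train_batch() is expected to yield one (x_batch, y_batch) training pair per call (see the training
# loop below); the exact sampling logic lives in Batch_Loader_ecir.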

assert args.embedding_dim % 200 == 0
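# The pretrained query/user/doc vectors loaded below are 200-dimensional, and the CapsE graph is
# built with embedding_size=200, so embedding_dim is constrained to a multiple of 200 here.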

pretrained_query = init_dataset_ecir(args.data + args.name + '/query2vec.200.init')
pretrained_user = init_dataset_ecir(args.data + args.name + '/user2vec.200.init')
pretrained_doc = init_dataset_ecir(args.data + args.name + '/doc2vec.200.init')

print("Using pre-trained initialization.")

lstEmbedQuery = assignEmbeddings(pretrained_query, query_indexes)
lstEmbedUser = assignEmbeddings(pretrained_user, user_indexes)
lstEmbedDoc = assignEmbeddings(pretrained_doc, doc_indexes)

lstEmbedQuery = np.array(lstEmbedQuery, dtype=np.float32)
lstEmbedUser = np.array(lstEmbedUser, dtype=np.float32)
lstEmbedDoc = np.array(lstEmbedDoc, dtype=np.float32)
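# Each lookup table is assumed to be a (num_entities, 200) float32 matrix whose rows line up with the
# corresponding *_indexes mapping (inferred from how assignEmbeddings is called above).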

print("Loading data... finished!")

# Training
# ==================================================
with tf.Graph().as_default():
    session_conf = tf.ConfigProto(allow_soft_placement=args.allow_soft_placement, log_device_placement=args.log_device_placement)
    session_conf.gpu_options.allow_growth = True
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        global_step = tf.Variable(0, name="global_step", trainable=False)
        capse = CapsE(sequence_length=3,
                      batch_size=20 * args.batch_size,
                      initialization=[lstEmbedQuery, lstEmbedUser, lstEmbedDoc],
                      embedding_size=200,
                      filter_size=args.filter_size,
                      num_filters=args.num_filters,
                      iter_routing=args.iter_routing,
                      num_outputs_secondCaps=args.num_outputs_secondCaps,
                      vec_len_secondCaps=args.vec_len_secondCaps,
                      useConstantInit=args.useConstantInit)
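
        # The graph is constructed with batch_size = 20 * args.batch_size; test_prediction below pads
        # and slices the evaluation data into chunks of exactly this size, so the factor 20 appears to
        # tie the evaluation chunking to the candidate lists being ranked.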

        # Define Training procedure
        #optimizer = tf.contrib.opt.NadamOptimizer(1e-3)
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        #optimizer = tf.train.RMSPropOptimizer(learning_rate=args.learning_rate)
        #optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        grads_and_vars = optimizer.compute_gradients(capse.total_loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        out_dir = os.path.abspath(os.path.join(args.run_folder, "runs_CapsE_SEARCH17", args.model_name))
        print("Writing to {}\n".format(out_dir))

        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """
            Run a single training step and return the batch loss.
            """
            feed_dict = {
                capse.input_x: x_batch,
                capse.input_y: y_batch
            }
            _, step, loss = sess.run([train_op, global_step, capse.total_loss], feed_dict)
            return loss

        # Predict scores for a batch of evaluation triples
        def predict(x_batch, y_batch):
            feed_dict = {
                capse.input_x: x_batch,
                capse.input_y: y_batch,
            }
            scores = sess.run([capse.predictions], feed_dict)
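            # Note: sess.run([...]) returns a one-element list here; np.append in test_prediction
            # flattens it into the growing 1-D results array.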
            return scores


        def test_prediction(x_batch, y_batch, lstOriginalRank):

            new_x_batch = np.concatenate(x_batch)
            new_y_batch = np.concatenate(y_batch, axis=0)

            # Pad with copies of the last triple so the total is an exact multiple of 20 * batch_size,
            # the fixed evaluation batch size the graph was built with.
            while len(new_x_batch) % (args.batch_size * 20) != 0:
                new_x_batch = np.append(new_x_batch, np.array([new_x_batch[-1]]), axis=0)
                new_y_batch = np.append(new_y_batch, np.array([new_y_batch[-1]]), axis=0)

            # Score the padded triples in chunks of 20 * batch_size and collect the scores.
            results = []
            listIndexes = range(0, len(new_x_batch), 20 * args.batch_size)
            for tmpIndex in range(len(listIndexes) - 1):
                results = np.append(results,
                                    predict(new_x_batch[listIndexes[tmpIndex]:listIndexes[tmpIndex + 1]],
                                            new_y_batch[listIndexes[tmpIndex]:listIndexes[tmpIndex + 1]]))
            results = np.append(results,
                                predict(new_x_batch[listIndexes[-1]:], new_y_batch[listIndexes[-1]:]))

            # For each original candidate list, recover the rank of its first triple from the scores.
            lstresults = []
            _start = 0
            for tmp in lstOriginalRank:
                _end = _start + len(tmp)
                lstsorted = np.argsort(results[_start:_end])
                lstresults.append(np.where(lstsorted == 0)[0] + 1)
                _start = _end

            return lstresults
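
        # A minimal sketch of the rank extraction above, assuming the true triple is the first entry
        # of each candidate list and that lower prediction scores rank first (np.argsort is ascending):
        #   scores = np.array([0.1, 0.9, 0.4, 0.7])
        #   np.argsort(scores)                        # -> array([0, 2, 3, 1])
        #   np.where(np.argsort(scores) == 0)[0] + 1  # -> array([1]), i.e. rank 1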


        wri = open(checkpoint_prefix + '.cls.txt', 'w')
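        # Each line written below has the form "epoch <n>: <valid MRR> <test MRR> <test P@1>",
        # followed by a summary of the best epoch (selected by validation MRR) at the end of training.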

        lstvalid_mrr = []
        lsttest_mrr = []
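        # num_batches_per_epoch is ceil(data_size / batch_size); e.g. 1000 training triples with
        # batch_size 128 gives int(999 / 128) + 1 = 8 batches per epoch.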
        num_batches_per_epoch = int((data_size - 1) / (args.batch_size)) + 1
        for epoch in range(args.num_epochs):
            for batch_num in range(num_batches_per_epoch):
                x_batch, y_batch = train_batch()
                train_step(x_batch, y_batch)
                current_step = tf.train.global_step(sess, global_step)

            valid_results = test_prediction(valid_triples, valid_val_triples, valid_rank_triples)
            test_results = test_prediction(test_triples, test_val_triples, test_rank_triples)
            valid_mrr = computeMRR(valid_results)
            test_mrr = computeMRR(test_results)
            test_p1 = computeP1(test_results)
            lstvalid_mrr.append(valid_mrr)
            lsttest_mrr.append([test_mrr, test_p1])

            wri.write("epoch " + str(epoch) + ": " + str(valid_mrr) + " " + str(test_mrr) + " " + str(test_p1) + "\n")

        index_valid_max = np.argmax(lstvalid_mrr)
        wri.write("\n--------------------------\n")
        wri.write("\nBest mrr in valid at epoch " + str(index_valid_max) + ": " + str(lstvalid_mrr[index_valid_max]) + "\n")
        wri.write("\nMRR and P1 in test: " + str(lsttest_mrr[index_valid_max][0]) + " " + str(lsttest_mrr[index_valid_max][1]) + "\n")
        wri.close()