# Code from https://github.com/pfnet-research/hyperbolic_wrapped_distribution/blob/master/lib/models/embedding.py


import copy
import wandb
import torch
import argparse
import importlib
import numpy as np
from math import ceil
from torch.optim import Adagrad
from torch.nn import functional as F
from torch.utils.data import DataLoader

from tasks.WordNet import Dataset, evaluation


class LRScheduler:
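    """Wrap an optimizer so that the learning rate is divided by `c` for the
    first `n_burnin_steps` updates and restored to the full `lr` afterwards."""
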
    def __init__(self, optimizer, lr, c, n_burnin_steps):
        self.optimizer = optimizer
        self.lr = lr
        self.n_burnin_steps = n_burnin_steps
        self.c = c
        self.n_steps = 0

    def step_and_update_lr(self):
        self._update_learning_rate()
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def _update_learning_rate(self):
        self.n_steps += 1
        if self.n_steps <= self.n_burnin_steps:
            lr = self.lr / self.c
        else:
            lr = self.lr

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr


if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--data_type', type=str, default='noun')
    parser.add_argument('--n_negatives', type=int, default=1)
    parser.add_argument('--latent_dim', type=int, required=True)
    parser.add_argument('--batch_size', type=int, default=50000)
    parser.add_argument('--lr', type=float, default=0.6)
    parser.add_argument('--n_epochs', type=int, default=10000)
    parser.add_argument('--dist', type=str, required=True, choices=['EuclideanNormal', 'IsotropicHWN', 'DiagonalHWN', 'RoWN', 'FullHWN'])
    parser.add_argument('--initial_sigma', type=float, default=0.01)
    parser.add_argument('--bound', type=float, default=37)  # margin used in the hinge loss
    parser.add_argument('--train_samples', type=int, default=1)
    parser.add_argument('--test_samples', type=int, default=100)
    parser.add_argument('--eval_interval', type=int, default=1000)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--c', type=float, default=40)  # learning rate is divided by c during burn-in
    parser.add_argument('--burnin_epochs', type=int, default=100)  # number of epochs trained at the reduced learning rate
    parser.add_argument('--device', type=str, default='cuda:0')
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark=True would reintroduce non-determinism despite the fixed seeds
    torch.set_default_tensor_type(torch.DoubleTensor)  # use float64 throughout for numerically stable hyperbolic operations

    dataset = Dataset(args)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=1)

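    # Dynamically load the requested latent distribution (EuclideanNormal, IsotropicHWN,
    # DiagonalHWN, RoWN or FullHWN) together with its embedding layer.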
    dist_module = importlib.import_module(f'distributions.{args.dist}')
    model = getattr(dist_module, 'EmbeddingLayer')(args, dataset.n_words).to(args.device)
    dist_fn = getattr(dist_module, 'Distribution')

    optimizer = Adagrad(model.parameters(), lr=args.lr)
    n_batches = int(ceil(len(dataset) / args.batch_size))
    n_burnin_steps = args.burnin_epochs * n_batches
    lr_scheduler = LRScheduler(optimizer, args.lr, args.c, n_burnin_steps)

    best_model = copy.deepcopy(model)
    best_score = None

    wandb.init(project='RoWN')
    wandb.run.name = 'wordnet'
    wandb.config.update(args)
    for epoch in range(1, args.n_epochs + 1):
        total_loss, total_kl_target, total_kl_negative = 0., 0., 0.
        total_diff = 0.
        n_examples = 0  # number of triples processed this epoch
        model.train()

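        # Each batch holds word-index triples: column 0 is the anchor, 1 the target, 2 the negative.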
        for x in loader:
            for param in model.parameters():
                param.grad = None
            x = x.to(args.device)
            mean, covar = model(x)
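            # One distribution per role, built from the embedding layer's predicted means and covariances.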
            dist_anchor = dist_fn(mean[:, 0, :], covar[:, 0, :])
            dist_target = dist_fn(mean[:, 1, :], covar[:, 1, :])
            dist_negative = dist_fn(mean[:, 2, :], covar[:, 2, :])

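            # Monte Carlo estimates of KL(anchor || target) and KL(anchor || negative),
            # using reparameterized samples drawn from the anchor distribution.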
            z = dist_anchor.rsample(args.train_samples)
            log_prob_anchor = dist_anchor.log_prob(z)
            log_prob_target = dist_target.log_prob(z)
            log_prob_negative = dist_negative.log_prob(z)
            kl_target = (log_prob_anchor - log_prob_target).mean(dim=0)
            kl_negative = (log_prob_anchor - log_prob_negative).mean(dim=0)
            
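            # Hinge loss: a triple is penalized unless the KL to the negative exceeds the KL to the target by at least `bound`.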
            loss = F.relu(args.bound + kl_target - kl_negative).mean()
            loss.backward()
            lr_scheduler.step_and_update_lr()
           
            total_loss += loss.item() * kl_target.size(0)
            total_kl_target += kl_target.sum().item()
            total_kl_negative += kl_negative.sum().item()
            total_diff += (kl_target - kl_negative).sum().item()
            n_examples += kl_target.size(0)

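        # Keep a copy of the model with the lowest training loss seen so far.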
        if best_score is None or best_score > total_loss:
            best_score = total_loss
            best_model = copy.deepcopy(model)

        print(f"Epoch {epoch:8d} | Total loss: {total_loss / n_batches:.3f} | KL Target: {total_kl_target / n_batches:.3f} | KL Negative: {total_kl_negative / n_batches:.3f}")
        wandb.log({
            'epoch': epoch,
            'train_loss': total_loss / n_batches,
            'train_kl_target': total_kl_target / n_batches,
            'train_kl_negative': total_kl_negative / n_batches
        })
        
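        # Periodically evaluate the current best snapshot: mean rank and mean average precision (MAP).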
        if epoch % args.eval_interval == 0 or epoch == args.n_epochs:
            best_model.eval()
            rank, ap = evaluation(args, best_model, dataset, dist_fn)
            print(f"===========> Mean rank: {rank} | MAP: {ap}")
            wandb.log({
                'epoch': epoch,
                'rank': rank,
                'map': ap
            })