import torch, torch.nn as nn, torch.nn.functional as F
from pdb import set_trace as stop
from tqdm import tqdm
from utils.util import get_class

from models import *
from utils.loss import MCBCELoss



def run_epoch(args, model, data, optimizer, epoch, desc, device, loss_weight=None, train=False,  warm=False, inference_with_sampling=False, stage='joint'):
    if train:
        model.train()
        if isinstance(model, (ProbCBM)):
            if warm and hasattr(model, 'cnn_module'):
                for p in model.cnn_module.parameters():
                    p.requires_grad = False
            elif hasattr(model, 'cnn_module'):
                for p in model.cnn_module.parameters():
                    p.requires_grad = True
        optimizer.zero_grad()
    else:
        model.eval()

    # pre-allocate full prediction and target tensors
    all_predictions = torch.zeros(len(data.dataset) * 2, args.num_labels).cpu()
    all_certainties = torch.zeros(len(data.dataset) * 2, args.num_concepts).cpu()
    all_cls_certainties = torch.zeros(len(data.dataset) * 2).cpu()
    all_targets = torch.zeros(len(data.dataset) * 2, args.num_labels).cpu()

    batch_idx = 0
    end_idx = 0
    loss_tot_dict = {'total': 0}

    # Set criterion for class and concept
    criterion_class = getattr(args, 'criterion_class', 'ce')
    if criterion_class == 'ce':
        criterion_class = nn.CrossEntropyLoss()
    else:
        raise ValueError('Got criterion_class', criterion_class)
    
    criterion_concept = getattr(args, 'criterion_concept', 'bce')
    if criterion_concept == 'bce':
        criterion_concept = nn.BCEWithLogitsLoss()
    elif criterion_concept == 'bce_prob':
        criterion_concept = nn.BCELoss()
    elif criterion_concept == 'MCBCELoss':
        in_criterion = nn.BCELoss(reduction='none')
        criterion_concept = get_class(criterion_concept, 'utils.loss')(criterion=in_criterion, reduction='mean', vib_beta=args.vib_beta, \
            group2concept=args.group2concept)
    else:
        raise ValueError('Got criterion_concept', criterion_concept)

    for batch in tqdm(data, mininterval=0.5, desc=desc, leave=False, ncols=50):

        images = batch['image'].float().to(device)
        target_class = batch['class_label'][:, 0].long().to(device)
        target_concept = batch['concept_label'].float().to(device)

        if train:
            preds_dict, losses_dict = model(images, target_concept=target_concept, target_class=target_class, T=args.n_samples_train, stage=stage)
        else:
            with torch.no_grad():
                preds_dict, losses_dict = model(images, target_concept=target_concept, target_class=target_class, inference_with_sampling=inference_with_sampling, T=args.n_samples_inference)

        B = images.shape[0]
        class_label_onehot, concept_labels, labels, concept_uncertainty, class_uncertainty = None, None, None, None, None
        if args.pred_class:
            class_labels = batch['class_label'].float()
            class_label_onehot = torch.zeros(class_labels.size(0), args.num_classes)
            class_label_onehot.scatter_(1, class_labels.long(), 1)
            labels = class_label_onehot

        concept_labels = batch['concept_label'].float()
        if args.pred_concept:
            if labels is not None:
                labels = torch.cat((concept_labels, labels), 1)
            else:
                labels = concept_labels
        assert (labels is not None)

        loss, pred = 0, None
        loss_iter_dict = {}
        if args.pred_concept:
            if isinstance(criterion_concept, MCBCELoss):
                pred_concept = preds_dict['pred_concept_prob']
                loss_concept, concept_loss_dict = criterion_concept(\
                    probs=preds_dict['pred_concept_prob'],
                    image_mean=preds_dict['pred_mean'], image_logsigma=preds_dict['pred_logsigma'],
                    concept_labels=target_concept, negative_scale=preds_dict['negative_scale'], shift=preds_dict['shift'])
                if 'pred_concept_uncertainty' in preds_dict.keys():
                    concept_uncertainty = preds_dict['pred_concept_uncertainty']
                for k, v in concept_loss_dict.items():
                    if k != 'loss':
                        loss_iter_dict['pcme_' + k] = v
            elif isinstance(criterion_concept, (nn.BCELoss)):
                pred_concept = preds_dict['pred_concept_prob']
                loss_concept = criterion_concept(pred_concept, target_concept)
                if 'pred_concept_uncertainty' in preds_dict.keys():
                    concept_uncertainty = preds_dict['pred_concept_uncertainty']
            else:
                pred_concept = preds_dict['pred_concept_logit']
                loss_concept = criterion_concept(pred_concept, target_concept)
                pred_concept = torch.sigmoid(pred_concept)
                if 'pred_concept_uncertainty' in preds_dict.keys():
                    concept_uncertainty = preds_dict['pred_concept_uncertainty']

            if stage != 'class':
                loss += loss_concept * loss_weight['concept']
            pred = pred_concept
            loss_iter_dict['concept'] = loss_concept

        if args.pred_class:
            if 'pred_class_logit' in preds_dict.keys():
                pred_class = preds_dict['pred_class_logit']
                loss_class = criterion_class(pred_class, target_class)
                pred_class = F.softmax(pred_class, dim=-1)
            else:
                assert 'pred_class_prob' in preds_dict.keys()
                pred_class = preds_dict['pred_class_prob']
                loss_class = F.nll_loss(pred_class.log(), target_class, reduction='mean')
            loss_iter_dict['class'] = loss_class

            if stage != 'concept':
                loss += loss_class * loss_weight['class']
            pred = pred_class if pred is None else torch.cat((pred_concept, pred_class), dim=1)

            if 'pred_class_uncertainty' in preds_dict.keys():
                class_uncertainty = preds_dict['pred_class_uncertainty']

        for k, v in losses_dict.items():
            loss_iter_dict[k] = v
            if k in loss_weight.keys() and loss_weight[k] != 0:
                loss += v * loss_weight[k]
        loss_out = loss

        for k, v in loss_iter_dict.items():
            if v != v:
                print(k, v)

        if train:
            loss_out.backward()
            # Grad Accumulation
            if ((batch_idx + 1) % args.grad_ac_steps == 0):
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_max_norm)
                optimizer.step()
                optimizer.zero_grad()

        ## Updates ##
        loss_tot_dict['total'] += loss_out.item()
        for k, v in loss_iter_dict.items():
            if k not in loss_tot_dict.keys():
                try:
                    loss_tot_dict[k] = v.item()
                except:
                    loss_tot_dict[k] = v
            else:
                try:
                    loss_tot_dict[k] += v.item()
                except:
                    loss_tot_dict[k] += v
        start_idx, end_idx = end_idx, end_idx + B

        if pred.size(0) != all_predictions[start_idx:end_idx].size(0):
            pred = pred.view(labels.size(0), -1)

        all_predictions[start_idx:end_idx] = pred.data.cpu()
        all_targets[start_idx:end_idx] = labels.data.cpu()
        if concept_uncertainty is not None:
            all_certainties[start_idx:end_idx] = concept_uncertainty.data.cpu()
        if class_uncertainty is not None:
            all_cls_certainties[start_idx:end_idx] = class_uncertainty.data.cpu()
        batch_idx += 1

    for k, v in loss_tot_dict.items():
        loss_tot_dict[k] = v / batch_idx


    return all_predictions[:end_idx], all_targets[:end_idx], all_certainties[:end_idx], all_cls_certainties[:end_idx], loss_tot_dict