import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import re
import os
import argparse
import random
import numpy as np
import time
import json
from loader import Loader2, Loader
from utils import setup_logger
from sklearn.cluster import KMeans
from custom_datasets import trans_dict, cls_dict, get_num_images
from models import models_dict
random.seed(0)

def get_args():
    parser = argparse.ArgumentParser(description='Active Learning Process')
    # Dataset-related argument
    parser.add_argument('--dataset', '-d', default='cifar10', type=str, help='Name of the dataset.')
    parser.add_argument('--datapath', default='DATAPATH', type=str, help='Path to the dataset.')

    # Model-related argument
    parser.add_argument('--net', '-n', default='vgg16', type=str, help='Name of the model.')

    # AL-related arguments
    parser.add_argument('--sorted_dataset_path', default='', type=str, help='path to the unlabeled_pool.txt generated by kmeans.py')
    parser.add_argument('--sort', default='high2low', choices=['low2high', 'high2low', 'uniform'], type=str, help='Sorting order.')
    parser.add_argument('--first', default='high1st', choices=['low1st', 'high1st', 'uni', 'rand'], type=str, help='Selection heuristic for the first cycle.')
    parser.add_argument('--sampling', '-s', type=str, required=True, default='confidence', choices=['confidence', 'entropy', 'loss', 'rand', 'kmeans', 'class_balanced'], help='Type of sampling strategy.')
    parser.add_argument('--cycles', default=None, type=int, help='Number of cycles.')
    parser.add_argument('--start_cycle', default=0, type=int, help='Starting cycle.')
    parser.add_argument('--beta', default=1.0, type=float, help='Balance factor (expands the per-cycle candidate window).')
    parser.add_argument('--resume', '-r', default=None, type=str, help='Checkpoint path for resuming training.')
    parser.add_argument('--save', default=None, type=str, help='Save path for trained models.')
    parser.add_argument('--addendum', default=5000, type=int, help='Length of unlabeled pool for labeling.')
    
    # Training-related arguments
    parser.add_argument("--milestones", nargs='+', type=int, default=[30, 60, 90], help='List of epoch milestones.')
    parser.add_argument('--epochs', default=100, type=int, help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int, help='Batch size for training.')
    parser.add_argument('--lr', default=0.1, type=float, help='Learning rate for training.')
    parser.add_argument('--per_samples_list', nargs='+', default=[10, 10, 10, 10, 10, 10, 10, 10, 10, 10], type=int, help='Percentage of labeled samples in each cycle.')
    parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for optimizer.')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for optimizer.')

    args = parser.parse_args()
    print(f'saving in {args.save}')
    return args
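
# Hypothetical example invocation (script name, dataset path, and save directory are
# placeholders, not values taken from this repository):
#   python main.py -d cifar10 --datapath ./data -n vgg16 -s confidence \
#     --sorted_dataset_path ./kmeans/unlabeled_pool.txt --save ./save/cifar10_vgg16_confidence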

def train(models, criterion, optimizers, epoch, trainloader):
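    '''Train the backbone for one epoch with cross-entropy loss and log loss/accuracy.'''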
    models['backbone'].train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizers['backbone'].zero_grad()

        outputs =  models['backbone'](inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizers['backbone'].step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    logger.info('Train:[{}] Loss: {:.3f} | Acc: {:.3f}'.format(epoch, train_loss/(batch_idx+1), 100.*correct/total))

def test(models, criterion, epoch, cycle):
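    '''Evaluate on the test set and checkpoint the backbone whenever accuracy improves.'''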
    global best_acc
    models['backbone'].eval()

    test_loss = 0
    correct = 0
    total = 0
    save_label = True if epoch == args.epochs - 1 else False 
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, feat = models['backbone'](inputs, is_feat=True)
            loss = criterion(outputs,  targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        logger.info('Test: [{}] Loss: {:.3f} | Acc: {:.3f}'.format(epoch, test_loss/(batch_idx+1), 100.*correct/total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        state = {
            'net': models['backbone'].state_dict(),
            'acc': acc,
            'epoch': epoch
        }
        torch.save(state, os.path.join(args.save, 'main_{}.pth'.format(cycle)))
        logger.info('Saved main_{}.pth with acc={}!'.format(cycle, acc))
        best_acc = acc

def get_class_balanced(models, samples, cycle):
    '''class-balanced sampling (pseudo labeling)'''
    # dictionary with args.num_classes keys as class labels
    net = models['backbone']
    class_dict = {x: [] for x in range(args.num_classes)}

    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,  transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    # overflow goes into remaining
    remaining = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            if len(class_dict[predicted.item()]) < 100: # keep at most 100 samples per predicted class
                class_dict[predicted.item()].append(samples[idx])
            else: # add the rest of them in remaining
                remaining.append(samples[idx])
            # progress_bar(idx, len(ploader))

    sample2k = []
    for items in class_dict.values():
        sample2k.extend(items)
        if len(items) < 100:
            # top up under-represented classes from the overflow pool
            add = 100 - len(items)
            sample2k.extend(remaining[:add])
            remaining = remaining[add:]
    
    return sample2k

def get_confidence(models, samples, cycle):
    '''default: confidence sampling (pseudo labeling)
       returns the samples with the lowest top-1 confidence score'''
    net = models['backbone']
    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,  transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    top1_scores = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            outputs = F.normalize(outputs, dim=1)
            probs = F.softmax(outputs, dim=1)
            top1_scores.append(probs[0][predicted.item()].cpu())  # top-1 confidence of the predicted class
            
    idx = np.argsort(top1_scores)
    samples = np.array(samples)
    return samples[idx[:num_samples_list[cycle]]]

def get_entropy(models, samples, cycle):
    '''entropy sampling'''
    net = models['backbone']
    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,  transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    entropies = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            # predictive entropy: -sum_c p_c * log p_c
            e = -1.0 * torch.sum(F.softmax(outputs, dim=1) * F.log_softmax(outputs, dim=1), dim=1)
            entropies.append(e.item())  # batch size is 1, so e holds a single value

    idx = np.argsort(entropies)
    samples = np.array(samples)
    # keep the highest-entropy (most uncertain) samples
    return samples[idx[-num_samples_list[cycle]:]]

def get_kmeans(models, samples, cycle):
    '''k means sampling'''
    net = models['backbone']
    sub5k = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=False,  transform=transform_test, path_list=samples)
    ploader = torch.utils.data.DataLoader(sub5k, batch_size=1, shuffle=False, num_workers=2)

    feats = []
    net.eval()
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(ploader):
            inputs, targets = inputs.to(device), targets.to(device)
            _, feat = net(inputs, is_feat=True)
            feats.append(np.array(torch.squeeze(feat).cpu()))
    
    k_means = KMeans(init='k-means++', n_clusters=num_samples_list[cycle], n_init=10).fit(feats)
    
    distances = k_means.transform(feats)  # distance of each sample to each cluster center, shape (num_samples, num_clusters)
    center_idx = np.argmin(distances, axis=0) # index of the sample closest to each cluster center
    samples = np.array(samples)

    return samples[center_idx]

def get_random(models, samples, cycle):
    '''random sampling'''
    idx = random.sample(range(len(samples)), num_samples_list[cycle])  # without replacement, so no duplicate selections
    samples = np.array(samples)
    return samples[idx] 

def save(name, file):
    np.save(name, file)
    print(name + ' saved!')

if __name__ == '__main__':
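    # Active-learning loop: each cycle selects new samples from the unlabeled pool with the
    # chosen strategy, adds them to the labeled set, and trains the backbone, warm-starting
    # from the previous cycle's best checkpoint.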
    args = get_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0
    args.num_images = get_num_images(args.dataset)

    num_samples_list = [int(i*0.01*args.num_images) for i in args.per_samples_list]
    args.cycles = len(num_samples_list) if args.cycles is None else args.cycles
    print(f'num samples list: {num_samples_list}')

    # Data
    print('==> Preparing data..')
    assert args.dataset in cls_dict

    args.num_classes = cls_dict[args.dataset]
    transform_train, transform_test = trans_dict[args.dataset]

    testset = Loader(path=os.path.join(args.datapath, args.dataset), is_train=False,  transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    # Model
    print('=> Loading net {}'.format(args.net))
    net = models_dict[args.net](num_classes=args.num_classes)

    net = torch.nn.DataParallel(net)
    models = {'backbone': net}
    cudnn.benchmark = True

    labeled_path = os.path.join(args.save, 'labeled')
    os.makedirs(labeled_path, exist_ok=True)

    logger = setup_logger(name='main', output=args.save)
    logger.info(args)
    logger.info(f'==> Saving in {args.save}')

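    # Map each sampling-strategy name to its selection function.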
    sampling_dict = {
        'class_balanced': get_class_balanced,
        'confidence': get_confidence,
        'entropy': get_entropy,
        'rand':get_random, 
        'kmeans': get_kmeans,
    }
    get_sampling = sampling_dict[args.sampling]
    
    auto_resume = False
    if auto_resume:
        main_best_file = os.path.join(args.save, 'main_best.txt')
        if os.path.exists(main_best_file):
            with open(main_best_file) as f:
                lines = f.readlines()
            if len(lines) > 1:
                latest_cycle = int(lines[-1].split()[0])
                args.resume = os.path.join(args.save, f"main_{latest_cycle}.pth")
                args.start_cycle = latest_cycle + 1
                print(f'Auto resume from {args.resume}, start from {args.start_cycle}')
            
    print('==> Training..')
    start_time = time.time()
    labeled = []
    for cycle in range(args.start_cycle, args.cycles):
        criterion = criterion_test = nn.CrossEntropyLoss()
            
        optim_backbone = optim.SGD(models['backbone'].parameters(), lr=args.lr, 
                                momentum=args.momentum, weight_decay=args.weight_decay)
        sched_backbone = torch.optim.lr_scheduler.MultiStepLR(optim_backbone, milestones=args.milestones)
        
        optimizers = {'backbone': optim_backbone, }
        schedulers = {'backbone': sched_backbone, }

        best_acc = 0
        logger.info('Cycle: {}'.format(cycle))

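        # Candidate window for this cycle: the pre-sorted unlabeled pool is consumed as a
        # sliding window of roughly beta*addendum samples that advances by `addendum` each
        # cycle; paths that are already labeled are removed before sampling.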
        if cycle == args.start_cycle:
            if args.sorted_dataset_path.endswith('.npy'):
                sorted_dataset = np.load(args.sorted_dataset_path)
            elif args.sorted_dataset_path.endswith('.txt'):
                with open(os.path.join(args.sorted_dataset_path), 'r') as f:
                    sorted_dataset = [line.strip() for line in f.readlines()]
            else:
                raise ValueError(args.sorted_dataset_path)
        if cycle == 0:
            begin_idx = 0
            end_idx = int(args.beta*args.addendum)
            samples = sorted_dataset[:end_idx]
            logger.info(f'sample range: {begin_idx} -> {end_idx} unlabeled length: {len(samples)}')
        elif cycle == args.cycles - 1:
            begin_idx = int(args.num_images-args.beta*args.addendum)
            end_idx = args.num_images
            ori_samples = sorted_dataset[begin_idx:]
            samples = np.setdiff1d(ori_samples, labeled)
            logger.info(f'sample range: {begin_idx} -> {end_idx} length: {len(ori_samples)} unlabeled length: {len(samples)}')
        else:
            begin_idx = int((cycle+1)*args.addendum - (args.beta-1)/2*args.addendum)
            end_idx = int((cycle+2)*args.addendum + (args.beta-1)/2*args.addendum)
            ori_samples = sorted_dataset[begin_idx:end_idx]
            samples = np.setdiff1d(ori_samples, labeled)
            logger.info(f'sample range: {begin_idx} -> {end_idx} length: {len(ori_samples)} unlabeled length: {len(samples)}')

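        # From the second cycle on, load the previous cycle's backbone and let the chosen
        # sampling strategy score the candidates; the first cycle uses the --first heuristic.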
        if cycle > 0:
            print('>> Getting previous checkpoint')
            if args.start_cycle > 0 and cycle == args.start_cycle:
                assert os.path.exists(args.resume)
                checkpoint = torch.load(args.resume)
                labeled = list(np.load(os.path.join(labeled_path, f'labeled_{cycle-1}.npy')))  # labeled set accumulated through the previous cycle
                best_acc = checkpoint['acc']
            else:
                checkpoint = torch.load(os.path.join(args.save, f'main_{cycle-1}.pth'))
            models['backbone'].load_state_dict(checkpoint['net'])
            print(f"cycle {cycle-1} best acc: {checkpoint['acc']}")
            sample2k = get_sampling(models, samples, cycle)
        else:
            samples = np.array(samples)
            k = num_samples_list[cycle]
            if args.first == 'uni':
                sample2k = samples[[int(j * len(samples) / k) for j in range(k)]]  # evenly spaced over the current pool
            elif args.first == 'high1st':
                sample2k = samples[:k]
            elif args.first == 'low1st':
                sample2k = samples[-k:]
            elif args.first == 'rand':
                sample2k = np.array(random.sample(list(samples), k))
            else: raise ValueError('Sort method {} is not supported!'.format(args.first))

        labeled.extend(sample2k)
        np.save(os.path.join(labeled_path, f'labeled_{cycle}.npy'), labeled)

        logger.info(f'>> Labeled length: {len(labeled)}')
        trainset = Loader2(path=os.path.join(args.datapath, args.dataset), is_train=True, transform=transform_train, path_list=labeled)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
        
        for epoch in range(args.epochs):
            train(models, criterion, optimizers, epoch, trainloader)
            test(models, criterion_test, epoch, cycle)
            present_time = time.time()
            epochs_before = epoch + 1 + (cycle - args.start_cycle) * args.epochs
            epochs_after = args.epochs - epoch - 1 + (args.cycles - cycle - 1) * args.epochs
            eta = (present_time - start_time) / epochs_before * epochs_after
            eta = '{}d {}'.format(int(eta // 86400), time.strftime("%H:%M:%S", time.gmtime(eta)))
            print('Eta: {}'.format(eta))
            schedulers['backbone'].step()

        with open(os.path.join(args.save, 'main_best.txt'), 'a') as f:
            f.write(str(cycle) + ' ' + str(best_acc)+'\n')

    print('done')