# -*- coding: utf-8 -*-
"""
Created on Wed Dec 11 16:13:28 2019

@author: Arthur
"""
import os
import numpy as np
import mlflow
import os.path
import tempfile


from torch.utils.data import DataLoader, Subset
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torch.nn
import torch.nn.functional as F

# These imports are used to create the training datasets
from data.datasets import (DatasetWithTransform, DatasetTransformer,
                           RawDataFromXrDataset, ConcatDataset_,
                           Subset_, ComposeTransforms, MultipleTimeIndices)

# Some utils functions
from train.utils import (DEVICE_TYPE, learning_rates_from_string,
                         run_ids_from_string, list_from_string)
from data.utils import load_training_datasets, load_data_from_run
from testing.utils import create_test_dataset
from testing.metrics import MSEMetric, MaxMetric
from train.base import Trainer
import train.losses
import models.transforms

import argparse
import importlib
import pickle

from data.xrtransforms import SeasonalStdizer

import models.submodels
import sys

import copy

from utils import TaskInfo
from dask.diagnostics import ProgressBar

def negative_int(value: str) -> int:
    """Parse ``value`` as an integer and return its negation.

    Used as the argparse ``type`` of ``--time_indices`` so that each value
    given on the command line is stored negated.
    """
    parsed = int(value)
    return -parsed

def check_str_is_None(s: str):
    """Return ``None`` if ``s`` spells 'none' (case-insensitive), else ``s``.

    Lets command-line options accept the literal string 'None' to mean
    "no value".
    """
    if s.lower() == 'none':
        return None
    return s

# PARAMETERS ---------
# Command-line interface of the training script.
description = ('Trains a model on a chosen dataset from the store. Allows '
               'to set training parameters via the CLI.')
parser = argparse.ArgumentParser(description=description)
parser.add_argument('exp_id', type=int,
                    help='Experiment id of the source dataset containing the '
                    'training data.')
parser.add_argument('run_id', type=str,
                    help='Run id of the source dataset')
parser.add_argument('--batchsize', type=int, default=8)
parser.add_argument('--n_epochs', type=int, default=100)
# argparse only passes *string* defaults through `type`, so the default is
# given directly in its parsed form: a dict whose key 0 maps to the initial
# learning rate (the remaining keys are used as MultiStepLR milestones when
# the scheduler is built).
parser.add_argument('--learning_rate', type=learning_rates_from_string,
                    default={0: 1e-3})
parser.add_argument('--train_split', type=float, default=0.8,
                    help='Between 0 and 1')
parser.add_argument('--test_split', type=float, default=0.8,
                    help='Between 0 and 1, greater than or equal to '
                    'train_split.')
parser.add_argument('--time_indices', type=negative_int, nargs='*',
                    help='Time indices; each value is negated when parsed.')
parser.add_argument('--printevery', type=int, default=20)
parser.add_argument('--weight_decay', type=float, default=0.05,
                    help="Deprecated. Weight decay of the Adam optimizer, "
                         "applied to all the network parameters.")
parser.add_argument('--model_module_name', type=str, default='models.models1',
                    help='Name of the module containing the nn model')
parser.add_argument('--model_cls_name', type=str, default='FullyCNN',
                    help='Name of the class defining the nn model')
parser.add_argument('--loss_cls_name', type=str,
                    default='HeteroskedasticGaussianLossV2',
                    help='Name of the loss function used for training.')
parser.add_argument('--transformation_cls_name', type=str,
                    default='SquareTransform',
                    help='Name of the transformation applied to outputs '
                    'required to be positive. Should be defined in '
                    'models.transforms.')
parser.add_argument('--submodel', type=str, default='transform1')
parser.add_argument('--features_transform_cls_name', type=str, default='None',
                    help='Deprecated')
parser.add_argument('--targets_transform_cls_name', type=str, default='None',
                    help='Deprecated')
params = parser.parse_args()

# Log the experiment_id and run_id of the source dataset
mlflow.log_param('source.experiment_id', params.exp_id)
mlflow.log_param('source.run_id', params.run_id)

# Training parameters
# Note that we use two indices for the train/test split. This is because we
# want to avoid the time correlation to play in our favour during test:
# samples between the two indices are used neither for training nor testing.
batch_size = params.batchsize
learning_rates = params.learning_rate
weight_decay = params.weight_decay
n_epochs = params.n_epochs
train_split = params.train_split
test_split = params.test_split
# With train_split > test_split the train and test subsets would overlap and
# test scores would be optimistic, so reject such (or out-of-range) values.
if not 0 <= train_split <= test_split <= 1:
    raise ValueError('Expected 0 <= train_split <= test_split <= 1, got '
                     'train_split={} and test_split={}.'
                     .format(train_split, test_split))
model_module_name = params.model_module_name
model_cls_name = params.model_cls_name
loss_cls_name = params.loss_cls_name
transformation_cls_name = params.transformation_cls_name
# Transforms applied to the features and targets ('none', in any case,
# means no transform)
features_transform_cls_name = check_str_is_None(
    params.features_transform_cls_name)
targets_transform_cls_name = check_str_is_None(
    params.targets_transform_cls_name)
# Submodel (for instance monthly means)
submodel = params.submodel


# Parameters specific to the input data
# Time indices used for prediction (already negated by negative_int).
indices = params.time_indices

# Other parameters
print_loss_every = params.printevery
model_name = 'trained_model.pth'

# Directories where temporary data will be saved. The parent directory keeps
# its original default and can be overridden through the TRAINING_TEMP_DIR
# environment variable so the script can run on other machines.
temp_parent_dir = os.environ.get('TRAINING_TEMP_DIR', '/scratch/ag7531/temp/')
data_location = tempfile.mkdtemp(dir=temp_parent_dir)
print('Created temporary dir at  ', data_location)

figures_directory = 'figures'
models_directory = 'models'
model_output_dir = 'model_output'


def _check_dir(dir_path):
    """Create the directory if it does not already exists"""
    if not os.path.exists(dir_path):
        os.mkdir(dir_path)


# Create the sub-directories of the temporary data location.
for _subdirectory in (figures_directory, models_directory, model_output_dir):
    _check_dir(os.path.join(data_location, _subdirectory))


# Device selection: the first GPU when CUDA is available, the CPU otherwise.
# TODO Allow CLI argument to select the GPU
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    device = torch.device("cuda:0")
    device_type = DEVICE_TYPE.GPU
else:
    device = torch.device("cpu")
    device_type = DEVICE_TYPE.CPU
print('Selected device type: ', device_type.value)
# FIN PARAMETERS --------------------------------------------------------------


# DATA-------------------------------------------------------------------------
# Load the data associated with the source run (params.run_id).
global_ds = load_data_from_run(params.run_id)
# Extract the training datasets from global_ds as described in the
# configuration file 'training_subdomains.yaml' (presumably one dataset per
# subdomain — TODO confirm).
xr_datasets = load_training_datasets(global_ds, 'training_subdomains.yaml')
# Per-dataset lists: full datasets (with transforms), train subsets and test
# subsets, kept in the same order as xr_datasets.
datasets, train_datasets, test_datasets = list(), list(), list()


for xr_dataset in xr_datasets:
    # TODO this is a temporary fix to implement seasonal patterns
    # A deep copy of the submodel is taken for each dataset so that whatever
    # fit_transform fits is not shared between datasets (NOTE(review):
    # assumes the submodel object stores fitted state — confirm).
    submodel_transform = copy.deepcopy(getattr(models.submodels, submodel))
    print(submodel_transform)
    # NOTE(review): this only rebinds the loop variable; the entries of
    # xr_datasets (used again in the test section) keep the untransformed
    # data unless fit_transform works in place — confirm this is intended.
    xr_dataset = submodel_transform.fit_transform(xr_dataset)
    # Load the (presumably dask-backed) data into memory with a progress bar.
    with ProgressBar(), TaskInfo('Computing dataset'):
        xr_dataset = xr_dataset.compute()
    print(xr_dataset)
    # Samples are indexed along 'time'; the surface velocities are the inputs
    # and S_x, S_y the outputs.
    dataset = RawDataFromXrDataset(xr_dataset)
    dataset.index = 'time'
    dataset.add_input('usurf')
    dataset.add_input('vsurf')
    dataset.add_output('S_x')
    dataset.add_output('S_y')
    # TODO temporary addition, should be made more general
    if submodel == 'transform2':
        dataset.add_output('S_x_d')
        dataset.add_output('S_y_d')
    # Chronological split: samples [0, train_index) are used for training and
    # [test_index, len(dataset)) for testing; samples in between are unused.
    train_index = int(train_split * len(dataset))
    test_index = int(test_split * len(dataset))
    # Compositions created without any transform here; transforms required
    # by the model are added later via add_transforms_from_model.
    features_transform = ComposeTransforms()
    targets_transform = ComposeTransforms()
    transform = DatasetTransformer(features_transform, targets_transform)
    dataset = DatasetWithTransform(dataset, transform)
    # dataset = MultipleTimeIndices(dataset)
    # dataset.time_indices = [0, ]
    train_dataset = Subset_(dataset, np.arange(train_index))
    test_dataset = Subset_(dataset, np.arange(test_index, len(dataset)))
    train_datasets.append(train_dataset)
    test_datasets.append(test_dataset)
    datasets.append(dataset)

# Concatenate datasets. This adds shape transforms to ensure that all regions
# produce fields of the same shape, hence should be called after saving
# the transformation so that when we're going to test on another region
# this does not occur.
train_dataset = ConcatDataset_(train_datasets)
test_dataset = ConcatDataset_(test_datasets)

# Dataloaders. Training batches are shuffled; both loaders drop the last
# incomplete batch.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, drop_last=True)

print('Size of training data: {}'.format(len(train_dataset)))
print('Size of validation data : {}'.format(len(test_dataset)))
# FIN DATA---------------------------------------------------------------------


# NEURAL NETWORK---------------------------------------------------------------
# Instantiate the loss named on the command line, built for the number of
# target channels of the first training dataset.
n_target_channels = datasets[0].n_targets
try:
    loss_cls = getattr(train.losses, loss_cls_name)
except AttributeError as e:
    raise type(e)('Could not find the specified loss class: ' +
                  str(e)) from e
criterion = loss_cls(n_target_channels)

# Recover the model's class, based on the corresponding CLI parameters. Each
# try block wraps only the lookup it reports on, so that unrelated errors
# (e.g. raised while executing the imported module) keep their own message.
try:
    models_module = importlib.import_module(model_module_name)
except ModuleNotFoundError as e:
    raise type(e)('Could not find the specified module for : ' +
                  str(e)) from e
try:
    model_cls = getattr(models_module, model_cls_name)
except AttributeError as e:
    raise type(e)('Could not find the specified model class: ' +
                  str(e)) from e
# The network outputs as many channels as the loss requires.
net = model_cls(datasets[0].n_features, criterion.n_required_channels)

# Final transformation applied by the network to the output channels listed
# in criterion.precision_indices.
try:
    transformation_cls = getattr(models.transforms, transformation_cls_name)
except AttributeError as e:
    raise type(e)('Could not find the specified transformation class: ' +
                  str(e)) from e
transformation = transformation_cls()
transformation.indices = criterion.precision_indices
net.final_transformation = transformation

print('--------------------')
print(net)
print('--------------------')
print('***')


# Log the text representation of the net into a txt artifact
with open(os.path.join(data_location, models_directory,
                       'nn_architecture.txt'), 'w') as f:
    print('Writing neural net architecture into txt file.')
    f.write(str(net))
# FIN NEURAL NETWORK ---------------------------------------------------------

# Add transforms required by the model.
for dataset in datasets:
    dataset.add_transforms_from_model(net)


# Training---------------------------------------------------------------------
# Move the network to the selected device before building the optimizer.
net.to(device)

# Adam optimizer and step-wise learning rate schedule: the rate starts at
# learning_rates[0] and is multiplied by 0.1 at each remaining key
# (MultiStepLR milestones). A dedicated name is used for the parameter list
# so the CLI namespace `params` is not shadowed.
net_parameters = list(net.parameters())
optimizer = optim.Adam(net_parameters, lr=learning_rates[0],
                       weight_decay=weight_decay)
lr_scheduler = MultiStepLR(optimizer, list(learning_rates.keys())[1:],
                           gamma=0.1)

trainer = Trainer(net, device)
trainer.criterion = criterion
trainer.print_loss_every = print_loss_every

# Metrics saved independently of the training criterion. The concatenated
# test dataset is bound as a default argument: a plain lambda would look up
# the global name `test_dataset` at call time, and that name may be rebound
# later in the script.
metrics = {'R2': MSEMetric(), 'Inf Norm': MaxMetric()}
for metric_name, metric in metrics.items():
    metric.inv_transform = (
        lambda x, ds=test_dataset: ds.inverse_transform_target(x))
    trainer.register_metric(metric_name, metric)

# Number of epochs during which training actually ran; stays 0 when
# n_epochs is 0 (the loop variable would then be undefined).
n_epochs_actual = 0
for i_epoch in range(n_epochs):
    print('Epoch number {}.'.format(i_epoch))
    # TODO remove clipping?
    train_loss = trainer.train_for_one_epoch(train_dataloader, optimizer,
                                             lr_scheduler, clip=1.)
    n_epochs_actual = i_epoch + 1
    test = trainer.test(test_dataloader)
    # The trainer signals early stopping with a sentinel string instead of
    # a (loss, metrics) pair.
    if test == 'EARLY_STOPPING':
        print(test)
        break
    test_loss, metrics_results = test
    # Log the training loss
    print('Train loss for this epoch is ', train_loss)
    print('Test loss for this epoch is ', test_loss)

    for metric_name, metric_value in metrics_results.items():
        print('Test {} for this epoch is {}'.format(metric_name, metric_value))
    mlflow.log_metric('train loss', train_loss, i_epoch)
    mlflow.log_metric('test loss', test_loss, i_epoch)
    # Log the extra metrics at the same step as the losses.
    mlflow.log_metrics(metrics_results, step=i_epoch)
# Update the logged number of actual training epochs
mlflow.log_param('n_epochs_actual', n_epochs_actual)

# FIN TRAINING ----------------------------------------------------------------

# Save the trained model to disk. The network is moved to the CPU first so
# that the saved state dict holds CPU tensors.
net.cpu()
full_path = os.path.join(data_location, models_directory, model_name)
torch.save(net.state_dict(), full_path)
# Move the network back to the selected device. `to(device)` works for both
# GPU and CPU devices, unlike `cuda(device)` which fails on CPU-only hosts.
net.to(device)

# Save other parts of the model
# TODO this should not be necessary
print('Saving other parts of the model')
full_path = os.path.join(data_location, models_directory, 'transformation')
with open(full_path, 'wb') as f:
    pickle.dump(transformation, f)

with TaskInfo('Saving trained model'):
    mlflow.log_artifact(os.path.join(data_location, models_directory))

# DEBUT TEST ------------------------------------------------------------------
# Run the trained network on the test part of each training dataset and save
# its output as one zarr store per dataset.
for i_dataset, region in enumerate(zip(datasets, test_datasets, xr_datasets)):
    region_dataset, region_test_dataset, region_xr_dataset = region
    # Start of this dataset's test period, recomputed with the formula used
    # when the train/test subsets were built (the module-level `test_index`
    # only holds the value computed for the last dataset).
    # NOTE(review): assumes len(region_dataset) is unchanged since the split
    # was made — confirm.
    region_test_index = int(test_split * len(region_dataset))
    test_dataloader = DataLoader(region_test_dataset, batch_size=batch_size,
                                 shuffle=False, drop_last=True)
    output_dataset = create_test_dataset(net, criterion.n_required_channels,
                                         region_xr_dataset,
                                         region_test_dataset,
                                         test_dataloader, region_test_index,
                                         device)

    # Save model output on the test dataset
    output_dataset.to_zarr(os.path.join(data_location, model_output_dir,
                                        f'test_output{i_dataset}'))

# Log artifacts
print('Logging artifacts...')
mlflow.log_artifact(os.path.join(data_location, figures_directory))
mlflow.log_artifact(os.path.join(data_location, model_output_dir))
print('Done...')