import os
import time
from tqdm import tqdm
from datetime import datetime
import wandb

import torch
from torchvision.utils import make_grid, save_image

from utils import *
from utils.metrics import *
# from skimage.measure import compare_psnr
from data import getLoader

# Smoke-test switch: when True, the training loop breaks after the first
# batch of every epoch (see the `if Train_test: break` in `train`).
Train_test = False

class Generic_train_test():
	"""Generic training/validation driver for cloud-removal models.

	Runs the epoch loop over ``train_loader``, optimizing a weighted sum of a
	pixel (L1) loss and an (MS-)SSIM loss, logs metrics to wandb on the local
	main process, periodically validates on ``val_loaders``, and writes
	periodic / best / last checkpoints to ``checkpoint_dir``.
	"""

	def __init__(self, opts, accelerator, net, optimizer, scheduler, train_loader, val_loaders, datasets_name, metrics):
		# opts: nested config dict (keys accessed throughout this class).
		# metrics: best values carried over from a resumed run
		#          ({'val_loss', 'val_ssim', 'val_psnr'}).
		# datasets_name: [train_name, val_name_0, val_name_1, ...].
		self.opts = opts
		self.net = net
		self.optimizer = optimizer
		self.scheduler = scheduler

		self.l1_loss = torch.nn.L1Loss()
		# Epoch after which the pixel (L1) loss term is dropped; defaults to
		# the total epoch count, i.e. never dropped.
		self.stop_l1 = opts['train'].get('stop_l1', opts['train']['epochs'])
		self.ms_ssim = MS_SSIM(accelerator)
		self.loss_funs = opts['train']['loss_funs']
		self.loss_weights = opts['train']['loss_weights']

		self.train_loader = train_loader
		self.val_loaders = val_loaders
		self.datasets_name = datasets_name

		# Best metrics seen so far (drive the "best" checkpoint decision).
		self.best_loss = metrics['val_loss']
		self.best_ssim = metrics['val_ssim']
		self.best_psnr = metrics['val_psnr']

		# Output directories.
		self.checkpoint_dir = opts['Experiment']['checkpoint_dir']
		self.result_dir = opts['Experiment']['result_dir']
		self.lambda_gray = 0.5

		# Optional modes: SAR->optical translation, image-id conditioning,
		# and mid-training dataset switching ({'epoch_<n>': loader_opts}).
		self.sar_trans = opts.get('sar_trans', False)
		self.use_id = opts.get('use_id', False)
		self.change_dataset = opts['train'].get('change_dataset')

	def decode_input(self, data):
		"""Hook for subclasses to unpack a raw batch; identity by default."""
		return data

	def _forward(self, batch, accelerator):
		"""Run one forward pass according to the configured input mode.

		Returns (cloudy_image, clear_label, prediction)."""
		image = batch['opt_cloudy']
		sar = batch['sar']
		label = batch['opt_clear']
		if self.use_id:
			pred = self.net(image, sar, batch['image_id'], accelerator)
		elif self.sar_trans:
			pred = self.net(sar)
		else:
			pred = self.net(image, sar)
		return image, label, pred

	def train(self, accelerator, run, start_epoch, end_epoch):
		"""Train from ``start_epoch + 1`` to ``end_epoch`` (inclusive)."""
		if accelerator.is_local_main_process:
			wandb.watch(self.net)
			wandb.define_metric("epoch")
			wandb.define_metric("lr", step_metric="epoch")
			for metric in ('train_loss', 'train_ssim', 'train_psnr', 'val_loss', 'val_ssim', 'val_psnr'):
				wandb.define_metric(metric, step_metric="epoch")

		accelerator.print(f"#Train dataset: {self.datasets_name[0]}")
		accelerator.print('#Train image nums: ', len(self.train_loader)*self.opts['datasets']['train']['batch_size'])

		valid_epoch_freq = self.opts.get('valid_epoch_freq', 1)
		# Seed with the carried-over best metrics so checkpoints written on
		# epochs without a validation pass still carry a well-defined dict.
		metrics_dict = {'val_loss': self.best_loss, 'val_ssim': self.best_ssim, 'val_psnr': self.best_psnr}

		for epoch in range(start_epoch+1, end_epoch+1):
			current_epoch = f'epoch_{epoch}'
			if self.change_dataset is not None and current_epoch in self.change_dataset:
				# Swap the training set mid-run per the configured schedule.
				change_opt = self.change_dataset[current_epoch]
				self.train_loader = accelerator.prepare(getLoader(change_opt))
				accelerator.print('#Change dataset and image nums: ',
								  len(self.train_loader) * self.opts['datasets']['train']['batch_size'])

			batch_time = AverageMeter('Time', ':6.3f')
			data_time = AverageMeter('Data', ':6.3f')
			m_l1_loss = AverageMeter('Loss', ':.4e')
			m_ssim = AverageMeter('SSIM', ':6.2f')
			m_psnr = AverageMeter('PSNR', ':6.2f')

			if accelerator.is_local_main_process:
				wandb.log({'lr': self.optimizer.param_groups[0]["lr"], 'epoch': epoch})
			self.net.train()
			end = time.time()
			with tqdm(total=len(self.train_loader), desc=f'[Epoch {epoch}/{end_epoch}]', unit='batch',
					  disable=not accelerator.is_local_main_process) as train_pbar:
				for step, batch in enumerate(self.train_loader):
					with accelerator.accumulate(self.net):
						data_time.update(time.time() - end)

						self.optimizer.zero_grad()
						image, label, pred = self._forward(batch, accelerator)

						loss_all = 0
						loss_l1 = None  # stays None while the pixel loss is inactive
						if 'pixel' in self.loss_funs and epoch < self.stop_l1:
							loss_l1 = self.l1_loss(pred, label)
							loss_all += loss_l1 * self.loss_weights[0]
						if 'ssim' in self.loss_funs:
							if self.loss_funs['ssim'] == 'ms_ssim':
								loss_ssim = 1 - self.ms_ssim(pred, label)
							else:
								loss_ssim = 1 - SSIM(pred, label)
							loss_all += loss_ssim * self.loss_weights[1]

						accelerator.backward(loss_all)
						self.optimizer.step()

						# Per-batch metrics (computed on this process only;
						# not gathered across processes during training).
						ssim = SSIM(pred, label).item()
						psnr = PSNR(pred, label)

						# BUGFIX: loss_l1 used to be read unconditionally and
						# raised NameError whenever the pixel loss was disabled
						# ('pixel' absent from loss_funs or epoch >= stop_l1).
						if loss_l1 is not None:
							m_l1_loss.update(loss_l1.item(), image.size(0))
						m_ssim.update(ssim, image.size(0))
						m_psnr.update(psnr, image.size(0))

						batch_time.update(time.time() - end)
						end = time.time()
						if accelerator.is_local_main_process:
							if step % self.opts['log_step_freq'] == 0:
								total_steps = len(self.train_loader) * (epoch-1) + step + 1
								wandb.log({'loss': m_l1_loss.avg, 'ssim': m_ssim.avg, 'psnr': m_psnr.avg, 'step': total_steps})

							if step % self.opts['visual_step_freq'] == 0 or step == len(self.train_loader)-1:
								# Concatenate input / prediction / target along the width axis.
								img_sample = torch.cat([image.data, pred.data, label.data], -1)
								grid = make_grid(img_sample, nrow=1, normalize=True)  # one sample per grid row
								save_image(grid, os.path.join(self.result_dir, 'train_images', f'img_epoch_{epoch}_step_{step}.png'))

							train_pbar.set_postfix(ordered_dict={'loss': m_l1_loss.avg, 'ssim': m_ssim.avg, 'psnr': m_psnr.avg})
							train_pbar.update()

						if Train_test:
							break  # smoke-test mode: one batch per epoch

			self.scheduler.step()

			if accelerator.is_local_main_process:
				# 'epoch' is included so wandb aligns these with the
				# step_metric declared in define_metric above.
				wandb.log({'train_loss': m_l1_loss.avg, 'train_ssim': m_ssim.avg, 'train_psnr': m_psnr.avg, 'epoch': epoch})

			accelerator.wait_for_everyone()

			# Validate on the first epoch of this run, every valid_epoch_freq
			# epochs, and on the final few epochs (the original best-checkpoint
			# test also fired near the end, but compared stale metrics).
			ran_validation = epoch == start_epoch + 1 or epoch % valid_epoch_freq == 0 or end_epoch - epoch < 5
			if ran_validation:
				val_loss, val_ssim, val_psnr = self.validate(epoch, accelerator, run)
				metrics_dict = {'val_loss': val_loss, 'val_ssim': val_ssim, 'val_psnr': val_psnr}

			checkpoint_dict = {'epoch': epoch, 'model': accelerator.unwrap_model(self.net).state_dict(),
							   'optimizer': self.optimizer.state_dict(), 'lr_scheduler': self.scheduler.state_dict(),
							   'metrics': metrics_dict}

			if epoch % self.opts['save_epoch_freq'] == 0:
				accelerator.save(checkpoint_dict, os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pth'))

			# BUGFIX: the best-checkpoint test used `epoch == start_epoch`,
			# which is never true inside this loop (it starts at
			# start_epoch + 1), and could compare a stale val_ssim from a
			# previous epoch; it now runs only on freshly validated epochs.
			if ran_validation and val_ssim > self.best_ssim:
				self.best_ssim = val_ssim
				accelerator.print(f'Best valid ssim {self.best_ssim} saved at epoch {epoch}')
				accelerator.save(checkpoint_dict, os.path.join(self.checkpoint_dir, 'checkpoint_best.pth'))

			# Always keep a rolling "last" checkpoint for resuming.
			accelerator.save(checkpoint_dict, os.path.join(self.checkpoint_dir, 'checkpoint_last.pth'))

		if accelerator.is_local_main_process:
			wandb.finish()

	@torch.no_grad()
	def validate(self, epoch, accelerator, run):
		"""Evaluate on every validation loader.

		Returns the running averages ``(val_loss, val_ssim, val_psnr)``
		accumulated over all loaders."""
		self.net.eval()
		batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
		m_l1_loss = AverageMeter('Loss', ':.4e', Summary.NONE)
		m_ssim = AverageMeter('SSIM', ':6.2f',  Summary.AVERAGE)
		m_psnr = AverageMeter('PSNR', ':6.2f',  Summary.AVERAGE)

		end = time.time()
		for idx, val_loader in enumerate(self.val_loaders):
			# datasets_name[0] is the train set; val names start at index 1.
			with tqdm(total=len(val_loader), desc=f'Val on {self.datasets_name[idx + 1]}', unit='batch',
					  disable=not accelerator.is_local_main_process) as val_pbar:
				for step, batch in enumerate(val_loader):
					image, label, pred = self._forward(batch, accelerator)

					# Gather across processes; gather_for_metrics also drops
					# the duplicated samples of a padded final batch.
					all_pred, all_label = accelerator.gather_for_metrics((pred, label))

					loss_l1 = self.l1_loss(all_pred, all_label)
					ssim = SSIM(all_pred, all_label).item()
					psnr = PSNR(all_pred, all_label)

					m_l1_loss.update(loss_l1.item(), image.size(0))
					m_ssim.update(ssim, image.size(0))
					m_psnr.update(psnr, image.size(0))

					batch_time.update(time.time() - end)
					end = time.time()

					if accelerator.is_local_main_process:
						if step == 0 or step % self.opts['valid_visual_step_freq'] == 0:
							# Concatenate input / prediction / target along the width axis.
							img_sample = torch.cat([image.data, pred.data, label.data], -1)
							grid = make_grid(img_sample, nrow=1, normalize=True)  # one sample per grid row
							save_image(grid, os.path.join(self.result_dir, 'valid_images', f'img_epoch_{epoch}_step_{step}.png'))

						val_pbar.set_postfix(
							ordered_dict={'loss': m_l1_loss.avg, 'ssim': m_ssim.avg, 'psnr': m_psnr.avg})
						val_pbar.update()

		accelerator.wait_for_everyone()
		if accelerator.is_local_main_process:
			# 'epoch' key added so wandb aligns these with the step_metric.
			wandb.log({'val_loss': m_l1_loss.avg, 'val_ssim': m_ssim.avg, 'val_psnr': m_psnr.avg, 'epoch': epoch})

		# BUGFIX: the third field was mislabelled "val_ssim" (duplicate label).
		accelerator.print(f'val_loss: {m_l1_loss.avg}, val_ssim: {m_ssim.avg}, val_psnr: {m_psnr.avg}')
		return m_l1_loss.avg, m_ssim.avg, m_psnr.avg