From abf2b2b84948f8e10100889f4f0a642548c67fbe Mon Sep 17 00:00:00 2001 From: Marius Arvinte <5852612+mariusarvinte@users.noreply.github.com> Date: Sat, 18 Mar 2023 23:06:34 -0700 Subject: [PATCH] Major clean-up and simplification --- hyperparam_tuning.py | 203 ------------------------------ inference.py | 250 ------------------------------------- loaders.py | 8 +- ncsnv2/losses/dsm.py | 8 +- test_score.py | 195 +++++++++++++++++++++++++++++ train.py => train_score.py | 125 +++++++------------ tune_hparams_score.py | 185 +++++++++++++++++++++++++++ 7 files changed, 434 insertions(+), 540 deletions(-) delete mode 100644 hyperparam_tuning.py delete mode 100644 inference.py create mode 100644 test_score.py rename train.py => train_score.py (66%) create mode 100644 tune_hparams_score.py diff --git a/hyperparam_tuning.py b/hyperparam_tuning.py deleted file mode 100644 index 9da07d4..0000000 --- a/hyperparam_tuning.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import numpy as np -import torch, sys, os, itertools, copy, argparse -sys.path.append('./') - -from tqdm import tqdm as tqdm -from ncsnv2.models.ncsnv2 import NCSNv2Deepest - -from loaders import Channels -from torch.utils.data import DataLoader -from matplotlib import pyplot as plt - -from dotmap import DotMap - -# Args -parser = argparse.ArgumentParser() -parser.add_argument('--gpu', type=int, default=2) -parser.add_argument('--spacing', nargs='+', type=float, default=[0.1]) -parser.add_argument('--pilot_alpha', nargs='+', type=float, default=[0.2]) -args = parser.parse_args() - -# Always !!! -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False -# Sometimes -torch.backends.cudnn.benchmark = True - -# GPU -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"; -os.environ["CUDA_VISIBLE_DEVICES"] = "0"; - -dc_boost = 1. -score_boost = 1. 
-noise_boost = 0.1 -# device = torch.device('cpu') -device = torch.device('cuda:0') - -# Target weights -# target_weights = './\ -# models_oct12_VarScaling/sigmaT39.1/intermediate_model.pt' -target_weights = './models_oct14/\ -numLambdas2_lambdaMin0.5_lambdaMax0.5_sigmaT39.1/final_model.pt' -contents = torch.load(target_weights, map_location=device) -# Extract config -config = contents['config'] -config.sampling.sigma = 0. # Nothing here -config.device = device -# Get a model -diffuser = NCSNv2Deepest(config) -diffuser = diffuser.to(device) -# !!! Load weights -diffuser.load_state_dict(contents['model_state']) -diffuser.eval() - -# Universal seeds -train_seed, val_seed = 1234, 4321 -# Get training config -# config.data.spacing_list = [0.1] -dataset = Channels(train_seed, config, norm=config.data.norm_channels) - -# Choose the core step size (epsilon) -config.sampling.steps_each = 3 -candidate_steps = np.logspace(-11, -7, 10000) -step_criterion = np.zeros((len(candidate_steps))) -gamma_rate = 1 / config.model.sigma_rate -for idx, step in enumerate(candidate_steps): - sigma_squared = config.model.sigma_end ** 2 - one_minus_ratio = (1 - step / sigma_squared) ** 2 - big_ratio = 2 * step /\ - (sigma_squared - sigma_squared * one_minus_ratio) - - # Criterion - step_criterion[idx] = one_minus_ratio ** config.sampling.steps_each * \ - (gamma_rate ** 2 - big_ratio) + big_ratio - -best_idx = np.argmin(np.abs(step_criterion - 1.)) -fixed_step_size = candidate_steps[best_idx] - -# Range of SNR, test channels and hyper-parameters -snr_range = np.arange(-10, 17.5, 2.5) -step_factor_range = np.asarray([0.0003, 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.]) # Multiplicative -spacing_range = np.asarray(args.spacing) # From a pre-defined index -pilot_alpha_range = np.asarray(args.pilot_alpha) -noise_range = 10 ** (-snr_range / 10.) -assert len(pilot_alpha_range) == 1, 'Too many pilot alphas for files!' 
- -# Global results -oracle_log = np.zeros((len(spacing_range), len(pilot_alpha_range), - len(step_factor_range), len(snr_range), - int(config.model.num_classes * \ - config.sampling.steps_each), 100)) # Should match data -result_dir = 'results_tuning' -if not os.path.isdir(result_dir): - os.makedirs(result_dir) - -# Wrap sparsity, steps and spacings -meta_params = itertools.product(spacing_range, pilot_alpha_range, step_factor_range) - -# For each hyper-combo -for meta_idx, (spacing, pilot_alpha, step_factor) in tqdm(enumerate(meta_params)): - # Unwrap indices - spacing_idx, pilot_alpha_idx, step_factor_idx = np.unravel_index( - meta_idx, (len(spacing_range), len(pilot_alpha_range), len(step_factor_range))) - - # Get a validation dataset and adjust parameters - val_config = copy.deepcopy(config) - val_config.data.spacing_list = [spacing] - val_config.mode.step_size = fixed_step_size * step_factor - val_config.data.num_pilots = int(np.floor(config.data.num_pilots * pilot_alpha)) - - # Create locals - val_dataset = Channels(val_seed, val_config, norm=[dataset.mean, dataset.std]) - val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), - shuffle=False, num_workers=0, drop_last=True) - val_iter = iter(val_loader) # For validation - - # Always the same initial points and data for validation - val_sample = next(val_iter) - _, val_P, _ = \ - val_sample['H'].to(device), val_sample['P'].to(device), \ - val_sample['Y'].to(device) - # Transpose pilots - val_P = torch.conj(torch.transpose(val_P, -1, -2)) - val_H_herm = val_sample['H_herm'].to(device) - val_H = val_H_herm[:, 0] + 1j * val_H_herm[:, 1] - # Initial value and measurements - init_val_H = torch.randn_like(val_H) - - # For each SNR value - for snr_idx, local_noise in tqdm(enumerate(noise_range)): - val_Y = torch.matmul(val_P, val_H) - val_Y = val_Y + \ - np.sqrt(local_noise) / np.sqrt(2.) 
* torch.randn_like(val_Y) - current = init_val_H.clone() - y = val_Y - forward = val_P - forward_h = torch.conj(torch.transpose(val_P, -1, -2)) - norm = [0., 1.] - oracle = val_H - - # Stop the count! - trailing_idx = 0 - # For each SNR point - with torch.no_grad(): - for step_idx in tqdm(range(val_config.model.num_classes)): - # Compute current step size and noise power - current_sigma = diffuser.sigmas[step_idx].item() - # Labels for diffusion model - labels = torch.ones(init_val_H.shape[0]).to(device) * step_idx - labels = labels.long() - - # For each step spent at that noise level - for inner_idx in range(val_config.sampling.steps_each): - # Compute score - current_real = torch.view_as_real(current).permute(0, 3, 1, 2) - # Get score - score = diffuser(current_real, labels) - # View as complex - score = \ - torch.view_as_complex(score.permute(0, 2, 3, 1).contiguous()) - - # Get un-normalized version for measurements - current_unnorm = norm[1] * current - # Compute alpha - alpha = val_config.model.step_size * \ - (current_sigma / val_config.model.sigma_end) ** 2 - - # Compute gradient for measurements in un-normalized space - meas_grad = torch.matmul(forward_h, - torch.matmul(forward, current_unnorm) - y) - # Re-normalize gradient to match score model - meas_grad = meas_grad / norm[1] - - # Annealing noise - grad_noise = np.sqrt(2 * alpha * noise_boost) * torch.randn_like(current) - - # Apply update - current = current + \ - score_boost * alpha * score - \ - dc_boost * alpha / (local_noise/2. 
+ current_sigma ** 2) * \ - meas_grad + grad_noise - - # Store loss - oracle_log[ - spacing_idx, pilot_alpha_idx, step_factor_idx, snr_idx, trailing_idx] = \ - (torch.sum(torch.square(torch.abs(current - oracle)), dim=(-1, -2))/\ - torch.sum(torch.square(torch.abs(oracle)), dim=(-1, -2))).cpu().numpy() - - # Increment count - trailing_idx = trailing_idx + 1 - - # Delete validation dataset - del val_dataset, val_loader - # torch.cuda.empty_cache() - -# Squeeze -oracle_log = np.squeeze(oracle_log) -# Plot average NMSE at best stopping point and step factor value -plt.figure(); plt.plot(snr_range, - 10*np.log10(np.min(np.mean(oracle_log, axis=-1), axis=(-1, -3)))); -plt.xlabel('SNR [dB]'); plt.ylabel('NMSE [dB]'); plt.grid(); plt.show() diff --git a/inference.py b/inference.py deleted file mode 100644 index c69cdcf..0000000 --- a/inference.py +++ /dev/null @@ -1,250 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import numpy as np -import torch, sys, os, itertools, copy, argparse -sys.path.append('./') - -from tqdm import tqdm as tqdm -from ncsnv2.models.ncsnv2 import NCSNv2Deepest - -from loaders import Channels -from torch.utils.data import DataLoader - -# Args -parser = argparse.ArgumentParser() -parser.add_argument('--gpu', type=int, default=0) -parser.add_argument('--channel', type=str, default='CDL-D') -parser.add_argument('--save_channels', type=int, default=0) -parser.add_argument('--spacing', nargs='+', type=float, default=[0.5]) -parser.add_argument('--pilot_alpha', nargs='+', type=float, default=[0.6]) -parser.add_argument('--noise_boost', type=float, default=0.001) -args = parser.parse_args() - -# Always !!! 
-torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False -# Sometimes -torch.backends.cudnn.benchmark = True - -# GPU -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"; -os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu); - -# Target weights - replace with target model -target_weights = './models/\ -numLambdas2_lambdaMin0.1_lambdaMax0.5_sigmaT39.1/final_model.pt' -contents = torch.load(target_weights) -# Extract config -config = contents['config'] -config.sampling.sigma = 0. # Nothing here - -# !!! 'Beta' in paper -noise_boost = args.noise_boost - -# Get a model -diffuser = NCSNv2Deepest(config) -diffuser = diffuser.cuda() -# !!! Load weights -diffuser.load_state_dict(contents['model_state']) -diffuser.eval() - -# Universal seeds -train_seed, val_seed = 1234, 9999 -# Get training config -config.data.channel = 'CDL-D' -dataset = \ - Channels(train_seed, config, norm=config.data.norm_channels) - -# Choose the core step size (epsilon) according to [Song '20] -config.sampling.steps_each = 3 -candidate_steps = np.logspace(-11, -7, 10000) -step_criterion = np.zeros((len(candidate_steps))) -gamma_rate = 1 / config.model.sigma_rate -for idx, step in enumerate(candidate_steps): - sigma_squared = config.model.sigma_end ** 2 - one_minus_ratio = (1 - step / sigma_squared) ** 2 - big_ratio = 2 * step /\ - (sigma_squared - sigma_squared * one_minus_ratio) - - # Criterion - step_criterion[idx] = one_minus_ratio ** config.sampling.steps_each * \ - (gamma_rate ** 2 - big_ratio) + big_ratio - -best_idx = np.argmin(np.abs(step_criterion - 1.)) -fixed_step_size = candidate_steps[best_idx] - -# Range of SNR, test channels and hyper-parameters -snr_range = np.arange(-10, 17.5, 2.5) # np.arange(-10, 17.5, 2.5) -step_factor_range = np.asarray([1.]) # Multiplicative -spacing_range = np.asarray(args.spacing) # From a pre-defined index -pilot_alpha_range = np.asarray(args.pilot_alpha) -noise_range = 10 ** (-snr_range / 10.) 
- -# Save test results -if args.save_channels: - num_channels = 200 # !!! More - alpha_match = [0.6, 0.8, 1.0] - step_match = [1260, 1161, 1104] - - # For noise = 0.001 !!! - step_snr_match = [ - [ 429, 583, 864, 885, 1213, 1353, 1541, 1652, 1870, 2216, 2328], # Alpha = 0.6, CDL-D - [ 523, 612, 790, 995, 1122, 1437, 1538, 1843, 2028, 2246, 2437], # Alpha = 0.8, CDL-D - [ 525, 687, 816, 1000, 1137, 1435, 1623, 1765, 1938, 2141, 2270], # Alpha = 1.0 - ] - - print('Saving test results!') - saved_H = np.zeros((len(pilot_alpha_range), - len(snr_range), num_channels, 64, 16), - dtype=np.complex64) -else: - num_channels = 100 # Validation -# Global results -oracle_log = np.zeros((len(spacing_range), len(pilot_alpha_range), - len(step_factor_range), len(snr_range), - int(config.model.num_classes * \ - config.sampling.steps_each), num_channels)) # Should match data -result_dir = 'results_seed%d' % val_seed -if not os.path.isdir(result_dir): - os.makedirs(result_dir) - -# Wrap sparsity, steps and spacings -meta_params = itertools.product(spacing_range, pilot_alpha_range, step_factor_range) - -# For each hyper-combo -for meta_idx, (spacing, pilot_alpha, step_factor) in tqdm(enumerate(meta_params)): - # Unwrap indices - spacing_idx, pilot_alpha_idx, step_factor_idx = np.unravel_index( - meta_idx, (len(spacing_range), len(pilot_alpha_range), - len(step_factor_range))) - - # Get a validation dataset and adjust parameters - val_config = copy.deepcopy(config) - val_config.data.channel = args.channel - val_config.data.spacing_list = [spacing] - val_config.mode.step_size = fixed_step_size * step_factor - val_config.data.num_pilots = int(np.floor(config.data.num_pilots * pilot_alpha)) - - # Create locals - val_dataset = Channels(val_seed, val_config, norm=[dataset.mean, dataset.std]) - val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), - shuffle=False, num_workers=0, drop_last=True) - val_iter = iter(val_loader) # For validation - print('There are %d validation 
channels!' % len(val_dataset)) - - # Always the same initial points and data for validation - val_sample = next(val_iter) - _, val_P, _ = \ - val_sample['H'].cuda(), val_sample['P'].cuda(), val_sample['Y'].cuda() - # Transpose pilots - val_P = torch.conj(torch.transpose(val_P, -1, -2)) - val_H_herm = val_sample['H_herm'].cuda() - val_H = val_H_herm[:, 0] + 1j * val_H_herm[:, 1] - # Initial value and measurements - init_val_H = torch.randn_like(val_H) - - # Save oracle once - if args.save_channels: - oracle_H = val_H.cpu().numpy() - - # For each SNR value - for snr_idx, local_noise in tqdm(enumerate(noise_range)): - if args.save_channels: - # Find exact stopping point - target_stop = step_snr_match[int(np.where(pilot_alpha == alpha_match)[0])][snr_idx] - print('For this SNR, stopping at %d!' % target_stop) - - val_Y = torch.matmul(val_P, val_H) - val_Y = val_Y + \ - np.sqrt(local_noise) * torch.randn_like(val_Y) - current = init_val_H.clone() - y = val_Y - forward = val_P - forward_h = torch.conj(torch.transpose(val_P, -1, -2)) - norm = [0., 1.] - oracle = val_H - - # Stop the count! 
- trailing_idx = 0 - mark_break = False - # For each SNR point - with torch.no_grad(): - for step_idx in tqdm(range(val_config.model.num_classes)): - # Compute current step size and noise power - current_sigma = diffuser.sigmas[step_idx].item() - # Labels for diffusion model - labels = torch.ones(init_val_H.shape[0]).cuda() * step_idx - labels = labels.long() - - # For each step spent at that noise level - for inner_idx in range(val_config.sampling.steps_each): - # Compute score - current_real = torch.view_as_real(current).permute(0, 3, 1, 2) - # Get score - score = diffuser(current_real, labels) - # View as complex - score = \ - torch.view_as_complex(score.permute(0, 2, 3, 1).contiguous()) - - # Get un-normalized version for measurements - current_unnorm = norm[1] * current - # Compute alpha - alpha = val_config.model.step_size * \ - (current_sigma / val_config.model.sigma_end) ** 2 - - # Compute gradient for measurements in un-normalized space - meas_grad = torch.matmul(forward_h, - torch.matmul(forward, current_unnorm) - y) - # Re-normalize gradient to match score model - meas_grad = meas_grad / norm[1] - - # Annealing noise - grad_noise = np.sqrt(2 * alpha * noise_boost) * \ - torch.randn_like(current) - - # Apply update - current = current + \ - alpha * (score - meas_grad /\ - (local_noise/2. + current_sigma ** 2)) + grad_noise - - # Store loss - oracle_log[ - spacing_idx, pilot_alpha_idx, step_factor_idx, snr_idx, trailing_idx] = \ - (torch.sum(torch.square(torch.abs(current - oracle)), dim=(-1, -2))/\ - torch.sum(torch.square(torch.abs(oracle)), dim=(-1, -2))).cpu().numpy() - - # Decide to early stop if saving - if args.save_channels: - if trailing_idx == target_stop: - saved_H[pilot_alpha_idx, snr_idx] = \ - current.cpu().numpy() - # Full stop - mark_break = True - break - - # Increment count - trailing_idx = trailing_idx + 1 - - # Full stop - if args.save_channels and mark_break: - print('Early stopping at step %d!' 
% target_stop) - break - - # Delete validation dataset - del val_dataset, val_loader - torch.cuda.empty_cache() - -# Save results to file based on noise -save_dict = {'spacing_range': spacing_range, - 'pilot_alpha_range': pilot_alpha_range, - 'step_factor_range': step_factor_range, - 'snr_range': snr_range, - 'val_config': val_config, - 'oracle_log': oracle_log - } -if args.save_channels: - save_dict['saved_H'] = saved_H - save_dict['oracle_H'] = oracle_H -torch.save(save_dict, - result_dir + '/%s_noise%.1e.pt' % (args.channel, noise_boost)) diff --git a/loaders.py b/loaders.py index 1977b85..d8b3d47 100644 --- a/loaders.py +++ b/loaders.py @@ -29,12 +29,8 @@ def __init__(self, seed, config, norm=None): contents = hdf5storage.loadmat(filename) channels = np.asarray(contents['output_h'], dtype=np.complex64) - # Use only first subcarrier of the first symbol - if config.data.mixed_channels: - self.channels.append(channels.reshape( - -1, channels.shape[-2], channels.shape[-1])) - else: - self.channels.append(channels[:, 0]) + # Use only first subcarrier of each symbol + self.channels.append(channels[:, 0]) # Convert to array self.channels = np.asarray(self.channels) diff --git a/ncsnv2/losses/dsm.py b/ncsnv2/losses/dsm.py index fb06451..6268867 100644 --- a/ncsnv2/losses/dsm.py +++ b/ncsnv2/losses/dsm.py @@ -1,7 +1,10 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + import torch def anneal_dsm_score_estimation(scorenet, samples, sigmas, - labels=None, anneal_power=2., hook=None): + labels=None, anneal_power=2.): # This always enters during training if labels is None: # Randomly sample sigma @@ -27,7 +30,4 @@ def anneal_dsm_score_estimation(scorenet, samples, sigmas, loss = 1 / 2. 
* ((scores - target) ** 2).sum( dim=-1) * used_sigmas.squeeze() ** anneal_power - if hook is not None: - hook(loss, labels) - return loss.mean(dim=0) \ No newline at end of file diff --git a/test_score.py b/test_score.py new file mode 100644 index 0000000..276a144 --- /dev/null +++ b/test_score.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +import torch, sys, os, itertools, copy, argparse +sys.path.append('./') + +from tqdm import tqdm as tqdm +from ncsnv2.models.ncsnv2 import NCSNv2Deepest +from loaders import Channels +from torch.utils.data import DataLoader +from matplotlib import pyplot as plt + +# Args +parser = argparse.ArgumentParser() +parser.add_argument('--gpu', type=int, default=0) +parser.add_argument('--train', type=str, default='CDL-C') +parser.add_argument('--test', type=str, default='CDL-C') +parser.add_argument('--save_channels', type=int, default=0) +parser.add_argument('--spacing', nargs='+', type=float, default=[0.5]) +parser.add_argument('--pilot_alpha', nargs='+', type=float, default=[0.6]) +args = parser.parse_args() + +# Disable TF32 due to potential precision issues +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False +torch.backends.cudnn.benchmark = True +# GPU +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + +# Target file +target_dir = './models/score/%s' % args.train +target_file = os.path.join(target_dir, 'final_model.pt') +contents = torch.load(target_file) +config = contents['config'] + +# Default hyper-parameters for pilot_alpha = 0.6, all SNR points +if args.train == 'CDL-A': + # !!! 
Not to be confused with 'pilot_alpha' that denotes fraction of pilots + alpha_step = 3e-11 # 'alpha' in paper Algorithm 1 + beta_noise = 0.01 # 'beta' in paper Algorithm 1 +elif args.train == 'CDL-B': + alpha_step = 3e-11 + beta_noise = 0.01 +elif args.train == 'CDL-C': + alpha_step = 3e-11 + beta_noise = 0.01 +elif args.train == 'CDL-D': + alpha_step = 3e-11 + beta_noise = 0.01 +elif args.train == 'Mixed': + alpha_step = 3e-11 + beta_noise = 0.01 + +# Instantiate model +diffuser = NCSNv2Deepest(config) +diffuser = diffuser.cuda() +# Load weights +diffuser.load_state_dict(contents['model_state']) +diffuser.eval() + +# Train and validation seeds +train_seed, val_seed = 1234, 4321 +# Get training dataset for normalization +config.data.channel = args.train +dataset = Channels(train_seed, config, norm=config.data.norm_channels) + +# Range of SNR, test channels and hyper-parameters +snr_range = np.arange(-10, 32.5, 2.5) +spacing_range = np.asarray(args.spacing) # From a pre-defined index +pilot_alpha_range = np.asarray(args.pilot_alpha) +noise_range = 10 ** (-snr_range / 10.) 
* config.data.image_size[1] +# Number of validation channels +num_channels = 100 + +# Global results +nmse_log = np.zeros((len(spacing_range), len(pilot_alpha_range), + len(snr_range), int(config.model.num_classes * \ + config.sampling.steps_each), num_channels)) +result_dir = './results/score/train-%s_test-%s' % ( + args.train, args.test) +os.makedirs(result_dir, exist_ok=True) + +# Wrap sparsity, steps and spacings +meta_params = itertools.product(spacing_range, pilot_alpha_range) + +# For each hyper-combo +for meta_idx, (spacing, pilot_alpha) in tqdm(enumerate(meta_params)): + # Unwrap indices + spacing_idx, pilot_alpha_idx = np.unravel_index( + meta_idx, (len(spacing_range), len(pilot_alpha_range))) + + # Get validation dataset + val_config = copy.deepcopy(config) + val_config.data.channel = args.test + val_config.data.spacing_list = [spacing] + val_config.data.num_pilots = int(np.floor(config.data.image_size[1] * pilot_alpha)) + val_dataset = Channels(val_seed, val_config, norm=[dataset.mean, dataset.std]) + val_loader = DataLoader(val_dataset, batch_size=num_channels, + shuffle=False, num_workers=0, drop_last=True) + val_iter = iter(val_loader) + print('There are %d validation channels' % len(val_dataset)) + + # Get all validation data explicitly + val_sample = next(val_iter) + _, val_P, _ = \ + val_sample['H'].cuda(), val_sample['P'].cuda(), val_sample['Y'].cuda() + # Transposed pilots + val_P = torch.conj(torch.transpose(val_P, -1, -2)) + val_H_herm = val_sample['H_herm'].cuda() + val_H = val_H_herm[:, 0] + 1j * val_H_herm[:, 1] + # Initial estimates + init_val_H = torch.randn_like(val_H) + + # For each SNR value + for snr_idx, local_noise in tqdm(enumerate(noise_range)): + val_Y = torch.matmul(val_P, val_H) + val_Y = val_Y + \ + np.sqrt(local_noise) * torch.randn_like(val_Y) + current = init_val_H.clone() + y = val_Y + forward = val_P + forward_h = torch.conj(torch.transpose(val_P, -1, -2)) + norm = [0., 1.] 
+ oracle = val_H # Ground truth channels + # Count every step + trailing_idx = 0 + + for step_idx in tqdm(range(val_config.model.num_classes)): + # Compute current step size and noise power + current_sigma = diffuser.sigmas[step_idx].item() + # Labels for diffusion model + labels = torch.ones(init_val_H.shape[0]).cuda() * step_idx + labels = labels.long() + + # Compute annealed step size + alpha = alpha_step * \ + (current_sigma / val_config.model.sigma_end) ** 2 + + # For each step spent at that noise level + for inner_idx in range(val_config.sampling.steps_each): + # Compute score using real view of data + current_real = torch.view_as_real(current).permute(0, 3, 1, 2) + with torch.no_grad(): + score = diffuser(current_real, labels) + # View as complex + score = \ + torch.view_as_complex(score.permute(0, 2, 3, 1).contiguous()) + + # Compute gradient for measurements in un-normalized space + meas_grad = torch.matmul(forward_h, + torch.matmul(forward, current) - y) + # Sample noise + grad_noise = np.sqrt(2 * alpha * beta_noise) * \ + torch.randn_like(current) + + # Apply update + current = current + alpha * (score - meas_grad /\ + (local_noise/2. 
+ current_sigma ** 2)) + grad_noise + + # Store loss + nmse_log[spacing_idx, pilot_alpha_idx, snr_idx, trailing_idx] = \ + (torch.sum(torch.square(torch.abs(current - oracle)), dim=(-1, -2))/\ + torch.sum(torch.square(torch.abs(oracle)), dim=(-1, -2))).cpu().numpy() + trailing_idx = trailing_idx + 1 + +# Use average estimation error to select best number of steps +avg_nmse = np.mean(nmse_log, axis=-1) +best_nmse = np.min(avg_nmse, axis=-1) + +# Plot results for all alpha values +plt.rcParams['font.size'] = 14 +plt.figure(figsize=(10, 10)) +for alpha_idx, local_alpha in enumerate(pilot_alpha_range): + plt.plot(snr_range, 10*np.log10(best_nmse[0, alpha_idx]), + linewidth=4, label='Alpha=%.2f' % local_alpha) +plt.grid(); plt.legend() +plt.title('Score-based channel estimation') +plt.xlabel('SNR [dB]'); plt.ylabel('NMSE [dB]') +plt.tight_layout() +plt.savefig(os.path.join(result_dir, 'results.png'), dpi=300, + bbox_inches='tight') +plt.close() + +# Save results to file based on noise +save_dict = {'nmse_log': nmse_log, + 'avg_nmse': avg_nmse, + 'best_nmse': best_nmse, + 'spacing_range': spacing_range, + 'pilot_alpha_range': pilot_alpha_range, + 'snr_range': snr_range, + 'val_config': val_config, + } +torch.save(save_dict, os.path.join(result_dir, 'results.pt')) \ No newline at end of file diff --git a/train.py b/train_score.py similarity index 66% rename from train.py rename to train_score.py index b3edadc..d9346e2 100644 --- a/train.py +++ b/train_score.py @@ -2,15 +2,10 @@ # -*- coding: utf-8 -*- import numpy as np -import torch, sys, os, copy +import torch, sys, os, copy, argparse sys.path.append('./') -from tqdm import tqdm as tqdm_base -def tqdm(*args, **kwargs): - if hasattr(tqdm_base, '_instances'): - for instance in list(tqdm_base._instances): - tqdm_base._decr_instances(instance) - return tqdm_base(*args, **kwargs) +from tqdm import tqdm as tqdm from ncsnv2.models import get_sigmas from ncsnv2.models.ema import EMAHelper @@ -20,16 +15,20 @@ def tqdm(*args, 
**kwargs): from loaders import Channels from torch.utils.data import DataLoader +from dotmap import DotMap -from dotmap import DotMap +parser = argparse.ArgumentParser() +parser.add_argument('--gpu', type=int, default=0) +parser.add_argument('--train', type=str, default='CDL-C') +args = parser.parse_args() -# Always !!! +# Disable TF32 due to potential precision issues torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False - +torch.backends.cudnn.benchmark = True # GPU -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"; -os.environ["CUDA_VISIBLE_DEVICES"] = "0"; +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) # Model config config = DotMap() @@ -54,30 +53,28 @@ def tqdm(*args, **kwargs): # Training config.training.batch_size = 32 config.training.num_workers = 4 -config.training.n_epochs = 800 +config.training.n_epochs = 400 config.training.anneal_power = 2 config.training.log_all_sigmas = False -config.training.eval_freq = 50 # In epochs # Data -config.data.channel = 'CDL-D' # Training and validation +config.data.channel = args.train config.data.channels = 2 # {Re, Im} -config.data.num_pilots = 64 -config.data.noise_std = 0.01 # 'Beta' in paper -config.data.image_size = [16, 64] # Channel size = Nr x Nt -config.data.mixed_channels = False -config.data.norm_channels = 'global' # Optional, no major impact +config.data.noise_std = 0 +config.data.image_size = [16, 64] # [Nt, Nr] for the transposed channel +config.data.num_pilots = config.data.image_size[1] +config.data.norm_channels = 'global' config.data.spacing_list = [0.5] # Training and validation -# Universal seeds +# Seeds for train and test datasets train_seed, val_seed = 1234, 4321 # Get datasets and loaders for channels dataset = Channels(train_seed, config, norm=config.data.norm_channels) dataloader = DataLoader(dataset, batch_size=config.training.batch_size, - shuffle=True, num_workers=config.training.num_workers, - 
drop_last=True) -# Create separate validation sets + shuffle=True, num_workers=config.training.num_workers, drop_last=True) + +# Validation data val_datasets, val_loaders, val_iters = [], [], [] for idx in range(len(config.data.spacing_list)): # Validation config @@ -97,26 +94,13 @@ def tqdm(*args, **kwargs): for idx in tqdm(range(len(dataset))): dist_matrix[idx] = np.linalg.norm( flat_channels[idx][None, :] - flat_channels, axis=-1) -# Pre-determined values -config.model.sigma_begin = 39.15 # !!! For CDL-D mixture -# config.model.sigma_begin = 27.77 # !!! For CDL-D lambda/2 - -# Apply Song's third recommendation -if False: - from scipy.stats import norm - candidate_gamma = np.logspace(np.log10(0.9), np.log10(0.99999), 1000) - gamma_criterion = np.zeros((len(candidate_gamma))) - for idx, gamma in enumerate(candidate_gamma): - gamma_criterion[idx] = \ - norm.cdf(np.sqrt(2 * np.prod(dataset[0]['H'].shape)) * (gamma - 1) + 3*gamma) - \ - norm.cdf(np.sqrt(2 * np.prod(dataset[0]['H'].shape)) * (gamma - 1) - 3*gamma) - best_idx = np.argmin(np.abs(gamma_criterion - 0.5)) -# Pre-determined -config.model.sigma_rate = 0.995 # !!! 
For everything -config.model.sigma_end = config.model.sigma_begin * \ +# Pre-determined values from 'Mixed' setting +config.model.sigma_begin = 39.15 +config.model.sigma_rate = 0.995 +config.model.sigma_end = config.model.sigma_begin * \ config.model.sigma_rate ** (config.model.num_classes - 1) -# Choose the step size (epsilon) according to [Song '20] +# Choose the inference step size (epsilon) according to [Song '20] candidate_steps = np.logspace(-13, -8, 1000) step_criterion = np.zeros((len(candidate_steps))) gamma_rate = 1 / config.model.sigma_rate @@ -130,63 +114,53 @@ def tqdm(*args, **kwargs): best_idx = np.argmin(np.abs(step_criterion - 1.)) config.model.step_size = candidate_steps[best_idx] -# Get a model +# Instantiate model diffuser = NCSNv2Deepest(config) diffuser = diffuser.cuda() -# Get optimizer + +# Instantiate optimizer optimizer = get_optimizer(config, diffuser.parameters()) -# Counter -start_epoch = 0 -step = 0 +# Instantiate counters and EMA helper +start_epoch, step = 0, 0 if config.model.ema: ema_helper = EMAHelper(mu=config.model.ema_rate) ema_helper.register(diffuser) -# Get a collection of sigma values +# Get all sigma values for the discretized VE-SDE sigmas = get_sigmas(config) -# Always the same initial points and data for validation +# Sample fixed validation data val_H_list = [] for idx in range(len(config.data.spacing_list)): val_sample = next(val_iters[idx]) val_H_list.append(val_sample['H_herm'].cuda()) -# More logging -config.log_path = 'models/\ -numLambdas%d_lambdaMin%.1f_lambdaMax%.1f_sigmaT%.1f' % ( - len(config.data.spacing_list), np.min(config.data.spacing_list), - np.max(config.data.spacing_list), config.model.sigma_begin) -if not os.path.exists(config.log_path): - os.makedirs(config.log_path) -# No sigma logging -hook = test_hook = None - -# Logged metrics +# Logging +config.log_path = './models/score/%s' % args.train +os.makedirs(config.log_path, exist_ok=True) train_loss, val_loss = [], [] -val_errors, val_epoch = [], [] 
+# For each epoch for epoch in tqdm(range(start_epoch, config.training.n_epochs)): + # For each batch for i, sample in tqdm(enumerate(dataloader)): - # Safety check diffuser.train() step += 1 - # Move data to device for key in sample: sample[key] = sample[key].cuda() - # Get loss on Hermitian channels + # Compute DSM loss using Hermitian channels loss = anneal_dsm_score_estimation( diffuser, sample['H_herm'], sigmas, None, - config.training.anneal_power, hook) + config.training.anneal_power) - # Keep a running loss + # Logging if step == 1: running_loss = loss.item() else: running_loss = 0.99 * running_loss + 0.01 * loss.item() - # Log train_loss.append(loss.item()) # Step @@ -213,8 +187,7 @@ def tqdm(*args, **kwargs): anneal_dsm_score_estimation( val_score, val_H_list[idx], sigmas, None, - config.training.anneal_power, - hook=test_hook) + config.training.anneal_power) # Store local_val_losses.append(val_dsm_loss.item()) # Sanity delete @@ -225,21 +198,19 @@ def tqdm(*args, **kwargs): # Print if len(local_val_losses) == 1: print('Epoch %d, Step %d, Train Loss (EMA) %.3f, \ - Val. Loss %.3f' % ( +Val. Loss %.3f' % ( epoch, step, running_loss, local_val_losses[0])) - elif len(local_val_losses) == 2: + elif len(local_val_losses) >= 2: print('Epoch %d, Step %d, Train Loss (EMA) %.3f, \ - Val. Loss (Split) %.3f %.3f' % ( +Val. 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Hyper-parameter search for score-based channel estimation.

Grid-searches the annealed Langevin dynamics step size ('alpha') and
noise scale ('beta') of a trained NCSNv2 score model over a range of
SNR values, logging per-sample NMSE at every annealing step. Results
are plotted and saved under ./results/score.
"""

import numpy as np
import torch, sys, os, itertools, copy, argparse
sys.path.append('./')

from tqdm import tqdm as tqdm
from ncsnv2.models.ncsnv2 import NCSNv2Deepest

from loaders import Channels
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--channel', type=str, default='CDL-C')
parser.add_argument('--spacing', type=float, default=0.5)
parser.add_argument('--pilot_alpha', type=float, default=0.6)
args = parser.parse_args()

# Disable TF32 due to potential precision issues
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.benchmark = True
# Pin the requested GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# Target file: trained score model weights + training config
target_dir = './models/score/%s' % args.channel
target_file = os.path.join(target_dir, 'final_model.pt')
contents = torch.load(target_file)
config = contents['config']

# Instantiate model and load weights
diffuser = NCSNv2Deepest(config)
diffuser = diffuser.cuda()
diffuser.load_state_dict(contents['model_state'])
diffuser.eval()

# Train and validation seeds
train_seed, val_seed = 1234, 4321
# Training dataset is instantiated only for its normalization statistics
config.data.channel = args.channel
dataset = Channels(train_seed, config, norm=config.data.norm_channels)

# Range of SNR, test channels and hyper-parameters
snr_range = np.arange(-10, 32.5, 2.5)
alpha_step_range = np.asarray([3e-11, 6e-11, 1e-10, 3e-10])
beta_noise_range = np.asarray([0.1, 0.01, 0.001])
# Per-sample noise power, scaled by the number of antennas
noise_range = 10 ** (-snr_range / 10.) * config.data.image_size[1]

# The validation data does not depend on (alpha, beta), so build it ONCE
# instead of rebuilding it inside the sweep on every combination
val_config = copy.deepcopy(config)
val_config.data.channel = args.channel
val_config.data.spacing_list = [args.spacing]
val_config.data.num_pilots = int(np.floor(
    config.data.image_size[1] * args.pilot_alpha))
val_dataset = Channels(val_seed, val_config, norm=[dataset.mean, dataset.std])
val_loader = DataLoader(val_dataset, batch_size=len(val_dataset),
                        shuffle=False, num_workers=0, drop_last=True)
print('There are %d validation channels' % len(val_dataset))

# Get all validation data explicitly; only move used tensors to the GPU
val_sample = next(iter(val_loader))
# Transposed (Hermitian) pilots
val_P = torch.conj(torch.transpose(val_sample['P'].cuda(), -1, -2))
val_H_herm = val_sample['H_herm'].cuda()
val_H = val_H_herm[:, 0] + 1j * val_H_herm[:, 1]
# Shared random initial estimate: every hyper-combo starts from the same
# point, which makes the comparison across combinations fair
init_val_H = torch.randn_like(val_H)
# Loop-invariant measurement operator and its Hermitian
forward = val_P
forward_h = torch.conj(torch.transpose(val_P, -1, -2))
oracle = val_H  # Ground truth channels

# Global results, sized from the actual validation set (was hard-coded 100)
nmse_log = np.zeros((len(alpha_step_range), len(beta_noise_range), len(snr_range),
                     int(config.model.num_classes * config.sampling.steps_each),
                     len(val_dataset)))
result_dir = './results/score'
os.makedirs(result_dir, exist_ok=True)

# Wrap hyper-parameters
meta_params = itertools.product(alpha_step_range, beta_noise_range)
num_combos = len(alpha_step_range) * len(beta_noise_range)

# For each hyper-combo
for meta_idx, (alpha_step, beta_noise) in tqdm(enumerate(meta_params),
                                               total=num_combos):
    # Unwrap indices into the result tensor
    alpha_idx, beta_idx = np.unravel_index(
        meta_idx, (len(alpha_step_range), len(beta_noise_range)))

    # For each SNR value
    for snr_idx, local_noise in tqdm(enumerate(noise_range)):
        # Noisy pilot measurements (fresh noise realization per combo)
        val_Y = torch.matmul(val_P, val_H)
        val_Y = val_Y + \
            np.sqrt(local_noise) * torch.randn_like(val_Y)
        current = init_val_H.clone()
        y = val_Y
        # Count every annealing step for logging
        trailing_idx = 0

        for step_idx in tqdm(range(val_config.model.num_classes)):
            # Current noise level of the annealing schedule
            current_sigma = diffuser.sigmas[step_idx].item()
            # Labels for diffusion model
            labels = torch.ones(init_val_H.shape[0]).cuda() * step_idx
            labels = labels.long()

            # Compute annealed step size
            alpha = alpha_step * \
                (current_sigma / val_config.model.sigma_end) ** 2

            # For each step spent at that noise level
            for inner_idx in range(val_config.sampling.steps_each):
                # Compute score using real view of data
                current_real = torch.view_as_real(current).permute(0, 3, 1, 2)
                with torch.no_grad():
                    score = diffuser(current_real, labels)
                # View as complex
                score = \
                    torch.view_as_complex(score.permute(0, 2, 3, 1).contiguous())

                # Compute gradient for measurements in un-normalized space
                meas_grad = torch.matmul(forward_h,
                                         torch.matmul(forward, current) - y)
                # Sample annealed Langevin noise
                grad_noise = np.sqrt(2 * alpha * beta_noise) * \
                    torch.randn_like(current)

                # Apply update
                current = current + alpha * (score - meas_grad /
                    (local_noise/2. + current_sigma ** 2)) + grad_noise

                # Store per-sample NMSE at this step
                nmse_log[alpha_idx, beta_idx, snr_idx, trailing_idx] = \
                    (torch.sum(torch.square(torch.abs(current - oracle)), dim=(-1, -2)) /
                     torch.sum(torch.square(torch.abs(oracle)), dim=(-1, -2))).cpu().numpy()
                trailing_idx = trailing_idx + 1

# Average estimation error and best stopping point
avg_nmse = np.mean(nmse_log, axis=-1)
best_nmse = np.min(avg_nmse, axis=-1)

# Find best hyper-parameters for each SNR point
best_alpha_snr, best_beta_snr = [], []
for snr_idx in range(len(snr_range)):
    local_nmse = best_nmse[..., snr_idx].flatten()
    best_idx = np.argmin(local_nmse)
    best_alpha_idx, best_beta_idx = np.unravel_index(
        best_idx, (len(alpha_step_range), len(beta_noise_range)))
    best_alpha_snr.append(alpha_step_range[best_alpha_idx])
    best_beta_snr.append(beta_noise_range[best_beta_idx])

# Plot all curves
plt.rcParams['font.size'] = 14
plt.figure(figsize=(10, 10))
for alpha_idx, local_alpha in enumerate(alpha_step_range):
    for beta_idx, local_beta in enumerate(beta_noise_range):
        plt.plot(snr_range, 10*np.log10(best_nmse[alpha_idx, beta_idx]),
                 linewidth=4, label='Alpha=%.2e, Beta=%.2e' % (local_alpha, local_beta))
plt.grid(); plt.legend()
plt.title('Score-based hyperparameter search')
plt.xlabel('SNR [dB]'); plt.ylabel('NMSE [dB]')
plt.tight_layout()
plt.savefig(os.path.join(result_dir, '%s-hyperparameters.png' % args.channel), dpi=300,
            bbox_inches='tight')
plt.close()

# Save full results to file
torch.save({'nmse_log': nmse_log,
            'avg_nmse': avg_nmse,
            'best_nmse': best_nmse,
            'best_alpha_snr': best_alpha_snr,
            'best_beta_snr': best_beta_snr,
            'snr_range': snr_range,
            'alpha_step_range': alpha_step_range,
            'beta_noise_range': beta_noise_range,
            'config': config, 'args': args
            }, os.path.join(result_dir, '%s-hyperparameters.pt' % args.channel))