diff --git a/src/datasets/wrappers/gpt2_commongen.py b/src/datasets/wrappers/gpt2_commongen.py index eaf5400..3ca2a35 100644 --- a/src/datasets/wrappers/gpt2_commongen.py +++ b/src/datasets/wrappers/gpt2_commongen.py @@ -10,7 +10,7 @@ def load_gpt2_commongen( path: str = 'datasets', seed: int = 42, - num_splits: int = 20 + num_splits: int = 40 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: splits = glob.glob(os.path.join(path, 'gpt2_commongen', 'common-gen.train.*')) assert len(splits) > 0, f"There are no GPT2 CommonGen-tuned generated splits: {path} {splits}" diff --git a/src/graphics/distributions.py b/src/graphics/distributions.py index f6faa73..9128e86 100644 --- a/src/graphics/distributions.py +++ b/src/graphics/distributions.py @@ -2,6 +2,7 @@ import numpy as np import torch +from torch.utils.data import DataLoader, TensorDataset from sklearn.neighbors import KernelDensity from matplotlib import pyplot as plt @@ -16,7 +17,7 @@ def plot_bivariate_samples_hmap( xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, zm: float = 0.0, - nbins: int = 600 + nbins: int = 256 ) -> np.ndarray: setup_tueplots(1, 1, hw_ratio=1.0) if xlim is None: @@ -86,7 +87,7 @@ def kde_samples_hmap( xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, zm: float = 0.0, - nbins: int = 600, + nbins: int = 256, *, bandwidth: float = 0.2 ) -> np.ndarray: @@ -113,23 +114,47 @@ def bivariate_pdf_heatmap( xlim: Tuple[float, float], ylim: Tuple[float, float], zm: float = 0.0, - nbins: int = 600, - device: Optional[Union[str, torch.device]] = None + nbins: int = 256, + batch_size: Optional[int] = None, + variables: Optional[Tuple[int, int]] = None, + device: Optional[Union[str, torch.device]] = None, ) -> np.ndarray: + if device is None: + device = 'cpu' zm_xamount = np.abs(xlim[1] - xlim[0]) zm_yamount = np.abs(ylim[1] - ylim[0]) xlim = (xlim[0] - zm * zm_xamount), (xlim[1] + zm * zm_xamount) ylim = (ylim[0] - zm * zm_yamount), (ylim[1] + zm * zm_yamount) xi, yi = np.mgrid[xlim[0]:xlim[1]:nbins * 1j, ylim[0]:ylim[1]:nbins * 1j] - xy = np.stack([xi.flatten(), yi.flatten()], axis=1)\ + xy = np.stack([xi.flatten(), yi.flatten()], axis=1) \ .astype(retrieve_default_dtype(numpy=True), copy=False) - if device is None: - device = 'cpu' - xy = torch.from_numpy(xy).to(device) - if isinstance(model, PC): - zi = model.log_prob(xy) + if model.num_variables < 1: + raise ValueError("A PC defined over two variables is required") + elif model.num_variables == 2: + xy = torch.from_numpy(xy).to(device) + if isinstance(model, PC): + zi = model.log_prob(xy) + else: + zi = model().log_prob(xy) else: - zi = model().log_prob(xy) + if batch_size is None: + batch_size = 128 + if variables is None or len(variables) != 2: + raise ValueError("The two variables to show the PDF of cannot be None") + # Perform variable marginalization + points = TensorDataset(xy[0], xy[1]) + points = DataLoader(points, batch_size=batch_size, shuffle=False, drop_last=False) + mar_mask = torch.ones(1, model.num_variables, dtype=torch.bool, device=device) + mar_mask[:, variables] = False + zi = list() + for xf, yf in points: + samples = torch.zeros(size=(xf.shape[0], model.num_variables), dtype=xf.dtype) + samples[:, variables[0]] = xf + samples[:, variables[1]] = yf + samples.to(device) + log_probs = model.log_marginal_prob(samples, mar_mask) + zi.append(log_probs) + zi = torch.concatenate(zi, dim=0) zi = torch.exp(zi).cpu().numpy().reshape(xi.shape) return zi diff --git a/src/graphics/utils.py b/src/graphics/utils.py index f3af4c0..65aaa8c 100644 --- a/src/graphics/utils.py +++ b/src/graphics/utils.py @@ -31,12 +31,20 @@ def setup_tueplots( #}) -def array_to_image(array: np.ndarray, vmin: Optional[float] = None, vmax: Optional[float] = None) -> np.ndarray: +def array_to_image( + array: np.ndarray, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + colorbar: bool = False +) -> np.ndarray: assert len(array.shape) == 2 xi, yi = np.mgrid[range(array.shape[0]), range(array.shape[1])] setup_tueplots(1, 1, hw_ratio=1.0) fig, ax = plt.subplots() - ax.pcolormesh(xi, yi, array, shading='auto', vmin=vmin, vmax=vmax) + cmap = 'turbo' if colorbar else 'jet' + pcm = ax.pcolormesh(xi, yi, array, shading='auto', cmap=cmap, vmin=vmin, vmax=vmax) + if colorbar: + plt.colorbar(pcm) ax.set_xticks([]) ax.set_yticks([]) return matplotlib_buffer_to_image(fig) diff --git a/src/pcs/hmm.py b/src/pcs/hmm.py index 2595c4b..6d17a94 100644 --- a/src/pcs/hmm.py +++ b/src/pcs/hmm.py @@ -3,6 +3,7 @@ import abc from typing import Tuple, Union, Optional +import numpy as np import torch from torch import nn @@ -95,12 +96,14 @@ def __init__( seq_length: int, hidden_size: int = 2, init_method: str = 'normal', - init_scale: float = 1.0 + init_scale: float = 1.0, + l2norm: bool = False ): assert seq_length > 1 super().__init__(num_variables=seq_length) self.vocab_size = vocab_size self.hidden_size = hidden_size + self.l2norm = l2norm latent_prior = torch.empty(self.hidden_size) init_params_(latent_prior, init_method, init_scale=init_scale) @@ -119,17 +122,21 @@ def eval_log_pf(self) -> Tuple[Optional[Union[torch.Tensor, Tuple[torch.Tensor, return None, log_pf def _latent_prior(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + w = self.latent_prior + if self.l2norm: + w = w / torch.linalg.vector_norm(w, ord=2, dim=0, keepdim=True) + if square: # x: (batch_size, hidden_size, hidden_size) # self.latent_prior: (hidden_size,) m_x, _ = torch.max(x, dim=2, keepdim=True) # (batch_size, hidden_size, 1) x = x_si * torch.exp(x - m_x) - x = torch.sum(self.latent_prior * x, dim=2) + x = torch.einsum('bij,j->bi', x, w) x_si = torch.sign(x.detach()) x = m_x.squeeze(dim=2) + safelog(torch.abs(x)) # (batch_size, hidden_size) m_x, _ = torch.max(x, dim=1, keepdim=True) # (batch_size, 1) x = x_si * torch.exp(x - m_x) - x = torch.sum(self.latent_prior * x, dim=1, keepdim=True) + x = torch.mm(x, w.unsqueeze(dim=1)) x_si = torch.sign(x.detach()) x = m_x + safelog(torch.abs(x)) # (batch_size, 1) return x, x_si @@ -137,23 +144,27 @@ def _latent_prior(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = Fals # self.latent_prior: (hidden_size,) m_x, _ = torch.max(x, dim=1, keepdim=True) x = x_si * torch.exp(x - m_x) - y = torch.sum(self.latent_prior * x, dim=1, keepdim=True) + y = torch.mm(x, w.unsqueeze(dim=1)) y_si = torch.sign(y.detach()) y = safelog(torch.abs(y)) + m_x return y, y_si def _latent_conds(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + w = self.latent_conds + if self.l2norm: + w = w / torch.linalg.vector_norm(w, ord=2, dim=1, keepdim=True) + if square: # x: (batch_size, hidden_size, hidden_size) # self.latent_conds: (hidden_size, hidden_size) m_x, _ = torch.max(x, dim=2, keepdim=True) # (batch_size, hidden_size, 1) x = x_si * torch.exp(x - m_x) - x = torch.einsum('pi,bji->bpj', self.latent_conds, x) + x = torch.einsum('pi,bji->bpj', w, x) x_si = torch.sign(x.detach()) x = m_x.permute(0, 2, 1) + safelog(torch.abs(x)) # (batch_size, hidden_size, hidden_size) m_x, _ = torch.max(x, dim=2, keepdim=True) # (batch_size, hidden_size, 1) x = x_si * torch.exp(x - m_x) - x = torch.einsum('qj,bpj->bpq', self.latent_conds, x) + x = torch.einsum('qj,bpj->bpq', w, x) x_si = torch.sign(x.detach()) x = m_x + safelog(torch.abs(x)) # (batch_size, hidden_size, hidden_size) return x, x_si @@ -161,21 +172,29 @@ def _latent_conds(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = Fals # self.latent_conds: (hidden_size, hidden_size) m_x, _ = torch.max(x, dim=1, keepdim=True) x = x_si * torch.exp(x - m_x) - y = torch.einsum('ij,bj->bi', self.latent_conds, x) + y = torch.einsum('ij,bj->bi', w, x) y_si = torch.sign(y.detach()) y = safelog(torch.abs(y)) + m_x return y, y_si def _emission_conds(self, x: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]: + e = self.emission_conds + if self.l2norm: + e = e / torch.linalg.vector_norm(e, ord=2, dim=0, keepdim=True) + zi = torch.arange(self.hidden_size, device=x.device).unsqueeze(dim=0) - w = self.emission_conds[zi, x[:, i].unsqueeze(dim=-1)] + w = e[zi, x[:, i].unsqueeze(dim=-1)] w_si = torch.sign(w.detach()) w = safelog(torch.abs(w)) return w, w_si def _emission_conds_normalize(self) -> Tuple[torch.Tensor, torch.Tensor]: - w_si = torch.sign(self.emission_conds.detach()) - w = safelog(torch.abs(self.emission_conds)) + e = self.emission_conds + if self.l2norm: + e = e / torch.linalg.vector_norm(e, ord=2, dim=0, keepdim=True) + + w_si = torch.sign(e.detach()) + w = safelog(torch.abs(e)) m_w, _ = torch.max(w, dim=1, keepdim=True) e_w = w_si * torch.exp(w - m_w) z = torch.mm(e_w, e_w.T) diff --git a/src/pcs/initializers.py b/src/pcs/initializers.py index 5cafb8d..725295d 100644 --- a/src/pcs/initializers.py +++ b/src/pcs/initializers.py @@ -16,7 +16,8 @@ 'gamma', 'dirichlet', 'xavier-uniform', - 'xavier-normal' + 'xavier-normal', + 'stiefel' ] @@ -55,11 +56,16 @@ def init_params_(tensor: torch.Tensor, method: str = 'normal', init_loc: float = fan_in, fan_out = tensor.shape[-1], tensor.shape[-2] std = np.sqrt(2.0 / float(fan_in + fan_out)) a = np.sqrt(3.0) * std - return nn.init.uniform_(tensor, -a, a) + nn.init.uniform_(tensor, -a, a) elif method == 'xavier-normal': fan_in, fan_out = tensor.shape[-1], tensor.shape[-2] std = np.sqrt(2.0 / float(fan_in + fan_out)) - return nn.init.normal_(tensor, 0.0, std) + nn.init.normal_(tensor, 0.0, std) + elif method == 'stiefel': + assert tensor.shape[-2] <= tensor.shape[-1] + nn.init.normal_(tensor, init_loc, init_scale) + q, _ = torch.linalg.qr(tensor.transpose(-2, -1), mode='reduced') + tensor.copy_(q.transpose(-2, -1)) else: raise NotImplementedError(f"Unknown initialization method called {method}") diff --git a/src/pcs/layers/compute.py b/src/pcs/layers/compute.py index 2554f37..7d88c33 100644 --- a/src/pcs/layers/compute.py +++ b/src/pcs/layers/compute.py @@ -25,9 +25,9 @@ def __init__( class MonotonicComputeLayer(ComputeLayer, abc.ABC): def forward(self, x: torch.Tensor) -> torch.Tensor: - pass + raise NotImplementedError() class BornComputeLayer(ComputeLayer, abc.ABC): def forward(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - pass + raise NotImplementedError() diff --git a/src/pcs/layers/input.py b/src/pcs/layers/input.py index c259b71..5d44580 100644 --- a/src/pcs/layers/input.py +++ b/src/pcs/layers/input.py @@ -84,8 +84,10 @@ def __init__( num_states: int = 2, init_method: str = 'normal', init_scale: float = 1.0, - exp_reparam: bool = False + exp_reparam: bool = False, + l2norm: bool = False ): + assert not exp_reparam or not l2norm, "Only one between --exp-reparam and --l2norm can be set true" super().__init__(rg_nodes, num_components) self.num_states = num_states weight = torch.empty(self.num_variables, self.num_replicas, self.num_components, num_states) @@ -94,10 +96,16 @@ def __init__( weight = torch.log(weight) self.weight = nn.Parameter(weight, requires_grad=True) self.exp_reparam = exp_reparam + self.l2norm = l2norm self._ohe = num_states <= 256 def log_pf(self) -> Tuple[torch.Tensor, torch.Tensor]: - weight = torch.exp(self.weight) if self.exp_reparam else self.weight + if self.exp_reparam: + weight = torch.exp(self.weight) + elif self.l2norm: + weight = self.weight / torch.linalg.vector_norm(self.weight, ord=2, dim=2, keepdim=True) + else: + weight = self.weight w_si = torch.sign(weight.detach()) # (num_variables, num_replicas, num_components, num_states) w = safelog(torch.abs(weight)) # (num_variables, num_replicas, num_components, num_states) @@ -109,7 +117,12 @@ def log_pf(self) -> Tuple[torch.Tensor, torch.Tensor]: return z.unsqueeze(dim=0), z_si.unsqueeze(dim=0) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - weight = torch.exp(self.weight) if self.exp_reparam else self.weight + if self.exp_reparam: + weight = torch.exp(self.weight) + elif self.l2norm: + weight = self.weight / torch.linalg.vector_norm(self.weight, ord=2, dim=2, keepdim=True) + else: + weight = self.weight # x: (-1, num_vars) # self.weight: (num_vars, num_comps, num_states) @@ -121,7 +134,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: w_si = torch.sign(w.detach()) w = safelog(torch.abs(w)) else: - weight = self.weight.permute(0, 3, 1, 2) + weight = weight.permute(0, 3, 1, 2) w = weight[torch.arange(weight.shape[0], device=x.device), x] w_si = torch.sign(w.detach()) w = safelog(torch.abs(w)) diff --git a/src/pcs/models.py b/src/pcs/models.py index ad33807..bf1bb84 100644 --- a/src/pcs/models.py +++ b/src/pcs/models.py @@ -443,7 +443,6 @@ def eval_log_pf(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: def log_score(self, x: torch.Tensor) -> torch.Tensor: (x, x_si), ldj = self._eval_input(x) - return self._eval_layers(x, x_si) + ldj def log_marginal_score( diff --git a/src/scripts/benchmark.py b/src/scripts/benchmark.py index d058931..d1a0595 100644 --- a/src/scripts/benchmark.py +++ b/src/scripts/benchmark.py @@ -9,6 +9,7 @@ import numpy as np from matplotlib import rcParams from torch.utils.data import DataLoader, TensorDataset +from torch import optim from graphics.utils import setup_tueplots from pcs.models import PC @@ -51,6 +52,9 @@ parser.add_argument( '--exp-bubble-radius', type=float, default=1.75, help="The exponent for computing the bubble sizes" ) +parser.add_argument( + '--eval-backprop', action='store_true', default=False, help="Whether to benchmark also backpropagation" +) parser.add_argument( '--seed', type=int, default=42, help="The seed for reproducibility" ) @@ -109,6 +113,10 @@ def benchmark_model( def run_benchmark(data_loader: DataLoader, model: PC, burnin_iterations: int = 1, eval_pf: bool = False) -> Tuple[float, float]: + if args.eval_backprop: + # Setup losses and a dummy optimizer (only used to free gradient tensors) + optimizer = optim.SGD(model.parameters(), lr=0.01) + elapsed_times = list() gpu_memory_peaks = list() for batch_idx, batch in enumerate(data_loader): @@ -130,10 +138,15 @@ def run_benchmark(data_loader: DataLoader, model: PC, burnin_iterations: int = 1 lls = model.log_pf(return_input=False) else: lls = model.log_score(batch) + if args.eval_backprop: + loss = -lls.mean() + loss.backward(retain_graph=False) # Free the autodiff graph end.record() torch.cuda.synchronize(device) # Synchronize CUDA Kernels before measuring time #end_time = time.perf_counter() gpu_memory_peaks.append(from_bytes_to_gib(torch.cuda.max_memory_allocated(device))) + if args.eval_backprop: + optimizer.zero_grad() # Free gradients tensors gc.enable() # Enable GC again gc.collect() # Manual GC #elapsed_times.append(end_time - start_time) @@ -158,7 +171,8 @@ def run_benchmark(data_loader: DataLoader, model: PC, burnin_iterations: int = 1 # Set device and the seed device = torch.device(args.device) set_global_seed(args.seed) - torch.set_grad_enabled(False) + if not args.eval_backprop: + torch.set_grad_enabled(False) # Setup the data set metadata, (train_dataloader, valid_dataloader, test_dataloader) = setup_data_loaders( @@ -167,7 +181,7 @@ def run_benchmark(data_loader: DataLoader, model: PC, burnin_iterations: int = 1 dataset = train_dataloader.dataset.tensors[0].numpy() nrows, ncols = 1, 2 - setup_tueplots(nrows, ncols) + setup_tueplots(nrows, ncols, hw_ratio=0.5) fig, ax = plt.subplots(nrows, ncols, sharey=True) def _bubble_size(s, inverse=False): return bubble_size( @@ -226,7 +240,7 @@ def _bubble_size(s, inverse=False): print(f"Plotting results for {m}") #desc = format_model_name(m) - desc = r"$Z \ (\pm^2)$" if ss['eval_pf'] else r"$c(\mathbf{X}) \ (\pm)$" + desc = r"$Z = \int c^2(\mathbf{x})\mathrm{d}\mathbf{x}$" if ss['eval_pf'] else r"$c(\mathbf{x})$" bench_bs_results = list(filter(lambda t: np.isfinite(t[0]), bench_bs_results)) bench_nc_results = list(filter(lambda t: np.isfinite(t[0]), bench_nc_results)) print(bench_bs_results) @@ -288,5 +302,7 @@ def _bubble_size(s, inverse=False): handletextpad=1.0 ) os.makedirs(os.path.join('figures', 'benchmarks'), exist_ok=True) - plt.savefig(os.path.join('figures', 'benchmarks', f'benchmark-{args.dataset}.pdf')) - + filename = f'benchmark-{args.dataset}' + if args.eval_backprop: + filename = f'{filename}-backprop' + plt.savefig(os.path.join('figures', 'benchmarks', f'{filename}.pdf')) diff --git a/src/scripts/benchmark_vars.py b/src/scripts/benchmark_vars.py new file mode 100644 index 0000000..500d865 --- /dev/null +++ b/src/scripts/benchmark_vars.py @@ -0,0 +1,412 @@ +import argparse +import gc +import os +from collections import defaultdict +from typing import Tuple + +import matplotlib.pyplot as plt +import torch +import numpy as np +from matplotlib import rcParams +from torch.utils.data import DataLoader, TensorDataset +from torch import optim + +from graphics.utils import setup_tueplots +from pcs.models import PC +from scripts.utils import set_global_seed, setup_data_loaders, setup_model +from pcs.utils import num_parameters + +parser = argparse.ArgumentParser( + description="Benchmark for squared circuits" +) +parser.add_argument( + '--datasets', type=str, required=True, help="The evaluation datasets, separated by space" +) +parser.add_argument( + '--num-iterations', type=int, default=1, help="The number of iterations per benchmark" +) +parser.add_argument( + '--burnin-iterations', type=int, default=1, help="Burnin iterations (additional to --num-iterations)" +) +parser.add_argument( + '--device', type=str, default='cpu', help="The device id" +) +parser.add_argument( + '--batch-size', type=int, default=512, help="The batch size to use" +) +parser.add_argument( + '--num-components', type=int, default=512, help="The layer dimensionality" +) +parser.add_argument( + '--min-bubble-radius', type=float, default=40.0, help="Bubble sizes minimum" +) +parser.add_argument( + '--scale-bubble-radius', type=float, default=1.0, help="Bubble sizes scaler" +) +parser.add_argument( + '--exp-bubble-radius', type=float, default=1.75, help="The exponent for computing the bubble sizes" +) +parser.add_argument( + '--specific-hparams', type=str, default="", + help="Specific hyperparameters (separated by space) per model (separated by dash) per dataset (separated by semicolon)" +) +# +# e.g., --specific-hparams "num_components=1024 batch_size=2048-num_components=1024 batch_size=512;num_components=128 batch_size=512-num_components=256 batch_size=512;num_components=32 batch_size=512-num_components=32 batch_size=512;num_components=512 batch_size=512-num_components=128 batch_size=512" +# +parser.add_argument( + '--eval-backprop', action='store_true', default=False, help="Whether to benchmark also backpropagation" +) +parser.add_argument( + '--seed', type=int, default=42, help="The seed for reproducibility" +) + + +def from_bytes_to_gib(bytes: int) -> float: + return bytes / (1024.0 * 1024.0 * 1024.0) + + +def bubble_size(s: float, a: float = 0.0, m: float = 1.0, p: float = 2.0, inverse: bool = False) -> float: + if inverse: + return ((s - a) ** (1.0 / p)) / m + return a + ((m * s) ** p) + + +def format_model_name(m: str, exp_reparam: bool = False) -> str: + if m == 'MonotonicPC': + return r"$+$" + elif m == 'BornPC': + if exp_reparam: + return r"$+^2$" + else: + return r"$\pm^2$" + elif m == 'MonotonicHMM': + return r"$+$" + elif m == 'BornHMM': + if exp_reparam: + return r"$+^2$" + else: + return r"$\pm^2$" + assert False + + +def benchmark_model( + model: PC, + dataset: np.ndarray, + batch_size: int, + num_iterations: int, + burnin_iterations: int, + eval_pf: bool = False +): + # Setup the data loader + total_num_iterations = burnin_iterations + num_iterations + while batch_size * total_num_iterations >= len(dataset): + dataset = np.concatenate([dataset, dataset], axis=0) + # assert batch_size * total_num_iterations < len(dataset), "Number of iterations is too large for this dataset and batch size" + ordering = np.random.permutation(len(dataset)) + dataset = dataset[ordering] + dataset = TensorDataset(torch.from_numpy(dataset[:batch_size * total_num_iterations])) + data_loader = DataLoader(dataset, batch_size, drop_last=True) + try: + mu_time, mu_memory = run_benchmark(data_loader, model, burnin_iterations=burnin_iterations, eval_pf=eval_pf) + except torch.cuda.OutOfMemoryError: + mu_time, mu_memory = np.nan, np.nan + return mu_time, mu_memory + + +def run_benchmark(data_loader: DataLoader, model: PC, burnin_iterations: int = 1, eval_pf: bool = False) -> Tuple[ + float, float]: + if args.eval_backprop: + # Setup losses and a dummy optimizer (only used to free gradient tensors) + optimizer = optim.SGD(model.parameters(), lr=0.01) + + elapsed_times = list() + gpu_memory_peaks = list() + for batch_idx, batch in enumerate(data_loader): + if isinstance(batch, (tuple, list)): + batch = batch[0] + # Run GC manually and then disable it + gc.collect() + gc.disable() + # Reset peak memory usage statistics + torch.cuda.reset_peak_memory_stats(device) + # torch.cuda.synchronize(device) # Synchronize CUDA operations + batch = batch.to(device) + # torch.cuda.synchronize(device) # Make sure the batch is already loaded (do not take into account this!) + # start_time = time.perf_counter() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + if eval_pf: + lls = model.log_pf(return_input=False) + else: + lls = model.log_score(batch) + if args.eval_backprop: + loss = -lls.mean() + loss.backward(retain_graph=False) # Free the autodiff graph + end.record() + torch.cuda.synchronize(device) # Synchronize CUDA Kernels before measuring time + # end_time = time.perf_counter() + gpu_memory_peaks.append(from_bytes_to_gib(torch.cuda.max_memory_allocated(device))) + if args.eval_backprop: + optimizer.zero_grad() # Free gradients tensors + gc.enable() # Enable GC again + gc.collect() # Manual GC + # elapsed_times.append(end_time - start_time) + elapsed_times.append(start.elapsed_time(end) * 1e-3) + + # Discard burnin iterations and compute averages + elapsed_times = elapsed_times[burnin_iterations:] + gpu_memory_peaks = gpu_memory_peaks[burnin_iterations:] + mu_time = np.mean(elapsed_times).item() + print(f"Mean time: {mu_time} -- Std. time: {np.std(elapsed_times)}") + mu_memory = np.mean(gpu_memory_peaks).item() + return mu_time, mu_memory + + +def entry_uniform_hparams_configuration() -> Tuple[dict, list]: + bench_results = defaultdict(dict) + datasets = args.datasets.split() + num_variables = list() + for dataset in datasets: + # Setup the data set + metadata, (train_dataloader, valid_dataloader, test_dataloader) = setup_data_loaders( + dataset, 'datasets', batch_size=1 + ) + x_data = train_dataloader.dataset.tensors[0].numpy() + num_variables.append(x_data.shape[1]) + + for idx, m in enumerate(models): + print(f"Benchmarking {m} ...") + model = setup_model( + m, dataset_metadata=metadata, rg_type='random', + rg_replicas=8, rg_depth=-1, num_components=num_components, + compute_layer='cp', init_method='uniform', init_scale=1.0, + seed=args.seed + ) + + num_params = num_parameters(model) + print(f"Model architecture:\n{model}") + print(f"Number of parameters: {num_params}") + model.to(device) + + if 'c' not in bench_results[m]: + bench_results[m]['c'] = dict() + bench_results[m]['Z'] = dict() + + mu_time, mu_memory = benchmark_model( + model, x_data, batch_size=batch_size, + num_iterations=args.num_iterations, burnin_iterations=args.burnin_iterations, + eval_pf=False + ) + bench_results[m]['c'][dataset] = (mu_time, mu_memory) + + mu_time, mu_memory = benchmark_model( + model, x_data, batch_size=batch_size, + num_iterations=args.num_iterations, burnin_iterations=args.burnin_iterations, + eval_pf=True + ) + bench_results[m]['Z'][dataset] = (mu_time, mu_memory) + + del model + return bench_results, num_variables + + +def entry_specific_hparams_configuration(hparams_conf: dict) -> Tuple[dict, list]: + log_likelihoods = { + 'gas': {'MonotonicPC': 5.56, 'BornPC': 10.98}, + 'hepmass': {'MonotonicPC': -22.45, 'BornPC': -20.41}, + 'miniboone': {'MonotonicPC': -32.11, 'BornPC': -26.92}, + 'bsds300': {'MonotonicPC': 123.30, 'BornPC': 128.38}, + } + bench_results = defaultdict(dict) + datasets = args.datasets.split() + num_variables = list() + for dataset, hps_conf in zip(datasets, hparams_conf): + # Setup the data set + metadata, (train_dataloader, valid_dataloader, test_dataloader) = setup_data_loaders( + dataset, 'datasets', batch_size=1 + ) + x_data = train_dataloader.dataset.tensors[0].numpy() + num_variables.append(x_data.shape[1]) + + for idx, (m, hps) in enumerate(zip(models, hps_conf)): + print(f"Benchmarking {m} ...") + other_hps = dict() + if dataset == 'bsds300': + other_hps['splines'] = True + other_hps['spline_knots'] = 512 + + model = setup_model( + m, dataset_metadata=metadata, rg_type='random', + rg_replicas=8, rg_depth=-1, num_components=int(hps['num_components']), + compute_layer='cp', init_method='uniform', init_scale=1.0, + seed=args.seed, **other_hps + ) + + num_params = num_parameters(model) + print(f"Model architecture:\n{model}") + print(f"Number of parameters: {num_params}") + model.to(device) + + if 'c' not in bench_results[m]: + bench_results[m]['c'] = dict() + bench_results[m]['Z'] = dict() + bench_results[m]['ll'] = dict() + bench_results[m]['ll'][dataset] = log_likelihoods[dataset][m] + + mu_time, mu_memory = benchmark_model( + model, x_data, batch_size=int(hps['batch_size']), + num_iterations=args.num_iterations, burnin_iterations=args.burnin_iterations, + eval_pf=False + ) + bench_results[m]['c'][dataset] = (mu_time, mu_memory) + + mu_time, mu_memory = benchmark_model( + model, x_data, batch_size=int(hps['batch_size']), + num_iterations=args.num_iterations, burnin_iterations=args.burnin_iterations, + eval_pf=True + ) + bench_results[m]['Z'][dataset] = (mu_time, mu_memory) + + del model + return bench_results, num_variables + + +if __name__ == '__main__': + args = parser.parse_args() + batch_size = args.batch_size + num_components = args.num_components + models = ['MonotonicPC', 'BornPC'] + + # Set device and the seed + device = torch.device(args.device) + set_global_seed(args.seed) + if not args.eval_backprop: + torch.set_grad_enabled(False) + + + def _bubble_size(s, inverse=False): + return bubble_size( + s, a=args.min_bubble_radius, + m=args.scale_bubble_radius, p=args.exp_bubble_radius, + inverse=inverse + ) + + + nrows, ncols = 1, 2 + setup_tueplots(nrows, ncols, hw_ratio=0.5) + fig, ax = plt.subplots(nrows, ncols, sharey=True) + + if args.specific_hparams: + hps_per_dataset = args.specific_hparams.split(';') + # I have no idea what I am doing, it's functional magic + hps_per_dataset_model = list(map(lambda hps: hps.split('-'), hps_per_dataset)) + hparams_conf = list( + map(lambda hps: list(map(lambda h: dict(map(lambda x: tuple(x.split('=')), h.split(' '))), hps)), + hps_per_dataset_model)) + print(hparams_conf) + bench_results, num_variables = entry_specific_hparams_configuration(hparams_conf) + else: + bench_results, num_variables = entry_uniform_hparams_configuration() + + print(bench_results) + print(f"Plotting results") + + scatter_plots = dict() + for idx, (m, br) in enumerate(bench_results.items()): + eval_c_results = list(map(lambda d: br['c'][d], br['c'].keys())) + eval_Z_results = list(map(lambda d: br['Z'][d], br['Z'].keys())) + print(eval_c_results) + print(eval_Z_results) + + if args.specific_hparams: + eval_ll_results = list(map(lambda d: br['ll'][d], br['ll'].keys())) + print(eval_ll_results) + + eval_c_results_time = list(map(lambda t: t[0], eval_c_results)) + eval_Z_results_time = list(map(lambda t: t[0], eval_Z_results)) + + if m == 'MonotonicPC': + desc_c = r'$c(\mathbf{x})$' + desc_Z = r'$Z = \int c(\mathbf{x}) \mathrm{d}\mathbf{x}$' + elif m == 'BornPC': + desc_c = r'$c^2(\mathbf{x})$' + desc_Z = r'$Z = \int c^2(\mathbf{x}) \mathrm{d}\mathbf{x}$' + else: + assert False + + sc_c = ax[0].scatter( + num_variables, eval_c_results_time, + color=f'C{idx}', alpha=.5, s=list(map(lambda t: _bubble_size(t[1]), eval_c_results)), + marker='o', label=desc_c + ) + ax[0].scatter( + num_variables, eval_c_results_time, + color='k', alpha=.6, s=1, marker='o' + ) + if args.specific_hparams: + for xi, yi, ll in zip(num_variables, eval_c_results_time, eval_ll_results): + xytext = (0, 1) if m == 'BornPC' else (0, -6) + ax[0].annotate(f"{ll}", xy=(xi, yi), fontsize=6, xytext=xytext, textcoords='offset points') + + sc_Z = ax[1].scatter( + num_variables, eval_Z_results_time, + color=f'C{idx}', alpha=.5, s=list(map(lambda t: _bubble_size(t[1]), eval_Z_results)), + marker='o', label=desc_Z + ) + ax[1].scatter( + num_variables, eval_Z_results_time, + color='k', alpha=.6, s=1, marker='o' + ) + if args.specific_hparams: + for xi, yi, ll in zip(num_variables, eval_Z_results_time, eval_ll_results): + xytext = (0, 1) if m == 'BornPC' else (0, -6) + ax[1].annotate(f"{ll}", xy=(xi, yi), fontsize=6, xytext=xytext, textcoords='offset points') + + scatter_plots[m] = {'c': sc_c, 'Z': sc_Z} + + ax[0].set_ylabel('Time per batch ($s$)') + ax[0].annotate(r'$|\mathbf{X}|$', xy=(1, 0), xytext=(1, -1 * rcParams['xtick.major.pad']), ha='right', va='top', + xycoords='axes fraction', textcoords='offset points') + ax[1].annotate(r'$|\mathbf{X}|$', xy=(1, 0), xytext=(1, -1 * rcParams['xtick.major.pad']), ha='right', va='top', + xycoords='axes fraction', textcoords='offset points') + ax[0].set_axisbelow(True) + ax[1].set_axisbelow(True) + # ax[0].set_xscale('log') + # ax[1].set_xscale('log') + ax[0].set_yscale('log') + ax[1].set_yscale('log') + # ax[1].set_xticks([2 * (10 ** 5), 10 ** 6]) + ax[0].margins(x=0.2, y=0.275) + ax[1].margins(x=0.2, y=0.275) + #ax[0].set_ylim(bottom=-0.05) + #ax[1].set_ylim(bottom=-0.05) + ax[0].grid(linestyle='--', alpha=0.3, linewidth=.5) + ax[1].grid(linestyle='--', alpha=0.3, linewidth=.5) + c_legend_loc = 'upper center' if args.specific_hparams else 'upper left' + c_legend = ax[0].legend( + loc=c_legend_loc, + # bbox_to_anchor=(1.0, 1.0), + labelspacing=0.4, + framealpha=0.4 + ) + + for i in range(len(c_legend.legend_handles)): + c_legend.legend_handles[i].set_sizes([20]) + ax[0].add_artist(c_legend) + z_legend_loc = 'upper center' if args.specific_hparams else 'upper left' + z_legend = ax[1].legend(loc=z_legend_loc, framealpha=0.4) + ax[0].text(-0.62, 0.45, "GPU Memory (GiB)", rotation=90, va='center', transform=ax[0].transAxes) + ax[0].legend( + loc='upper right', bbox_to_anchor=(-0.24, 1.05), + labelspacing=1.4 if args.specific_hparams else 2.2, frameon=False, + *scatter_plots[models[-1]]['Z'].legend_elements( + prop='sizes', func=lambda s: _bubble_size(s, inverse=True), + alpha=0.6, fmt="{x:.0f}", num=4 + ) + ) + os.makedirs(os.path.join('figures', 'benchmarks'), exist_ok=True) + filename = 'benchmark-vars-hparams' if args.specific_hparams else 'benchmark-vars' + if args.eval_backprop: + filename = f'{filename}-backprop' + plt.savefig(os.path.join('figures', 'benchmarks', f'{filename}.pdf')) diff --git a/src/scripts/engine.py b/src/scripts/engine.py index a22d7e0..c41d340 100644 --- a/src/scripts/engine.py +++ b/src/scripts/engine.py @@ -16,14 +16,12 @@ from pcs.layers import COMPUTE_LAYERS from pcs.optimizers import OPTIMIZERS_NAMES, setup_optimizer from pcs.models import PCS_MODELS, PC, TensorizedPC -from graphics.utils import array_to_image from region_graph import REGION_GRAPHS from pcs.utils import num_parameters from scripts.logger import Logger from scripts.utils import set_global_seed, evaluate_model_log_likelihood,\ bits_per_dimension, perplexity, \ build_run_id, setup_data_loaders, setup_model, setup_experiment_path, get_git_revision_hash -from graphics.distributions import bivariate_pdf_heatmap, bivariate_pmf_heatmap class Engine: @@ -43,12 +41,10 @@ def __init__(self, args: Namespace): if args.exp_alias: run_group = f'{run_group}-{args.exp_alias}' exp_path = setup_experiment_path(args.dataset, args.model, args.exp_alias, run_id) - if args.save_checkpoint: - kwargs['checkpoint_path'] = os.path.join(args.checkpoint_path, exp_path) - os.makedirs(kwargs['checkpoint_path'], exist_ok=True) - if args.tboard_path: - kwargs['tboard_path'] = os.path.join(args.tboard_path, exp_path) - os.makedirs(kwargs['tboard_path'], exist_ok=True) + kwargs['checkpoint_path'] = os.path.join(args.checkpoint_path, exp_path) + os.makedirs(kwargs['checkpoint_path'], exist_ok=True) + kwargs['tboard_path'] = os.path.join(args.tboard_path, exp_path) + os.makedirs(kwargs['tboard_path'], exist_ok=True) if args.wandb_path: kwargs['wandb_path'] = args.wandb_path kwargs['wandb_kwargs'] = { @@ -67,10 +63,11 @@ def __init__(self, args: Namespace): 'valid': None, 'test': None } - + self.model: Optional[Union[PC, Flow]] = None self.optimizer: Optional[torch.optim.Optimizer] = None self.scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None + self._log_distribution = self.args.log_distribution def shutdown(self): self.logger.close() @@ -94,6 +91,7 @@ def _hparams(self) -> Dict[str, Any]: 'spline_order': self.args.spline_order, 'spline_knots': self.args.spline_knots, 'exp_reparam': self.args.exp_reparam, + 'l2norm': self.args.l2norm, 'optimizer': self.args.optimizer, 'learning_rate': self.args.learning_rate, 'batch_size': self.args.batch_size, @@ -191,19 +189,6 @@ def _eval_step( test_ppl = perplexity(test_avg_ll, self.metadata['num_variables']) self.logger.info(f"[{self.args.dataset}] Epoch {epoch_idx}, Test ppl: {test_ppl:.03f}") self.logger.log_scalar('Test/ppl', test_ppl, step=epoch_idx) - - if self.metadata['type'] == 'artificial' or\ - (self.metadata['type'] == 'categorical' and self.metadata['num_variables'] == 2): - xlim, ylim = self.metadata['domains'] - if self.args.heavy_logging: - if self.args.discretize: - pmf_hmap = bivariate_pmf_heatmap(self.model, xlim, ylim, device=self._device) - self.logger.log_image(f'{self.args.model} PMF', array_to_image(pmf_hmap, vmin=0.0), step=epoch_idx) - self.logger.save_array(pmf_hmap, 'pmf.npy') - else: - pdf_hmap = bivariate_pdf_heatmap(self.model, xlim, ylim, device=self._device) - self.logger.log_image(f'{self.args.model} PDF', array_to_image(pdf_hmap, vmin=0.0), step=epoch_idx) - self.logger.save_array(pdf_hmap, 'pdf.npy') metrics['best_valid_epoch'] = epoch_idx metrics['best_valid_avg_ll'] = valid_avg_ll metrics['best_valid_std_ll'] = valid_std_ll @@ -258,6 +243,8 @@ def run(self): self.dataloaders['train'] = train_dataloader self.dataloaders['valid'] = valid_dataloader self.dataloaders['test'] = test_dataloader + self._log_distribution &= self.metadata['type'] == 'artificial' or \ + (self.metadata['type'] == 'categorical' and self.metadata['num_variables'] == 2) self.logger.info(f"Number of variables: {self.metadata['num_variables']}") # Initialize the model @@ -269,7 +256,7 @@ def run(self): exp_reparam=self.args.exp_reparam, binomials=self.args.binomials, splines=self.args.splines, spline_order=self.args.spline_order, spline_knots=self.args.spline_knots, init_method=self.args.init_method, init_scale=self.args.init_scale, - dequantize=self.args.dequantize, seed=self.args.seed + dequantize=self.args.dequantize, l2norm=self.args.l2norm, seed=self.args.seed ) # Instantiate the optimizer @@ -301,23 +288,22 @@ def run(self): checkpoint_path = self.args.load_checkpoint_path else: checkpoint_path = self.args.checkpoint_path - + # If alternate checkpoint hparams given, replace values in hparams from CL checkpoint_args = copy(self.args) for hp in self.args.checkpoint_hparams.split(';'): hp_name, hp_value = hp.split('=') checkpoint_args.__setattr__(hp_name.replace('-', '_'), hp_value) - + checkpoint_run_id = build_run_id(checkpoint_args) checkpoint_exp_path = setup_experiment_path( checkpoint_args.dataset, checkpoint_args.model, checkpoint_args.exp_alias, checkpoint_run_id) - + # Loading the model checkpoint_filepath = os.path.join(checkpoint_path, checkpoint_exp_path, 'model.pt') - state_dict = torch.load(checkpoint_filepath, map_location=self._device) + state_dict = torch.load(checkpoint_filepath, map_location='cpu') self.model.load_state_dict(state_dict['weights']) self.model.to(self._device) - del state_dict # TODO: figure out what this operation is doing if 'Born' in self.args.model and 'Monotonic' in checkpoint_args.model: @@ -334,6 +320,7 @@ def run(self): p.data.exp_() else: self.optimizer.load_state_dict(state_dict['opt']) + del state_dict self.logger.info(f"Checkpoint loaded from {checkpoint_filepath}") metrics = self._eval_step(0, metrics) @@ -368,13 +355,14 @@ def run(self): gamma=self.args.amount_lr_decay, verbose=True ) - if self.args.heavy_logging: - if self.metadata['type'] == 'artificial' or\ - (self.metadata['type'] == 'categorical' and self.metadata['num_variables'] == 2): - self.logger.log_image('Ground Truth', self.metadata['hmap']) + if self._log_distribution: + self.logger.save_array(self.metadata['hmap'], 'gt.npy') + self.logger.log_distribution( + self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device) # The train loop diverged = False + opt_counter = 0 for epoch_idx in range(1, self.args.num_epochs + 1): self.model.train() running_average_loss = 0.0 @@ -398,6 +386,14 @@ def run(self): self.logger.info(f"[{self.args.dataset}] Loss is not finite") diverged = True break + if opt_counter % ( + max(1, int(1e-1 * self.args.log_frequency)) if epoch_idx == 1 + else (max(1, int(2e-1 * self.args.log_frequency)) if epoch_idx == 2 + else self.args.log_frequency)) == 0: + if self._log_distribution: + self.logger.log_distribution( + self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device) + opt_counter += 1 if diverged: self.logger.info(f"Diverged, exiting ...") break diff --git a/src/scripts/experiment.py b/src/scripts/experiment.py index dcaa387..fb8af6d 100644 --- a/src/scripts/experiment.py +++ b/src/scripts/experiment.py @@ -21,8 +21,10 @@ parser.add_argument('--device', default='cpu', type=str, help="The Torch device to use") parser.add_argument('--data-path', default='datasets', type=str, help="The data root path") parser.add_argument('--tboard-path', default='', type=str, help="The Tensorboard path, empty to disable") -parser.add_argument('--disable-heavy-logging', dest='heavy_logging', action='store_false', default=True, - help="Whether to disable heavy logging procedures") +parser.add_argument('--log-distribution', action='store_true', default=False, + help="Whether to log the learned distribution") +parser.add_argument('--log-frequency', default=100, + help="The frequency for logging distributions") parser.add_argument('--wandb-path', default='', type=str, help="The W&B path, empty to disable") parser.add_argument('--wandb-project', default='born-pcs', type=str, help="The W&B project") parser.add_argument('--wandb-sweeps', type=int, default=0, help="How many hyperparameters to sweep, 0 to disable") @@ -64,6 +66,8 @@ help="The amount of noise to apply relative to the least squares initialization method") parser.add_argument('--exp-reparam', action='store_true', default=False, help="Whether to reparameterize the parameters of BornPCs via exponentiation") +parser.add_argument('--l2norm', action='store_true', default=False, + help="Wether to apply L2 norm to the parameters (valid only for HMMs)") parser.add_argument('--init-method', choices=INIT_METHODS, default=INIT_METHODS[0], help="Parameters initialisers") parser.add_argument('--init-scale', type=float, default=1.0, help="The initialization scale for the layers") parser.add_argument('--optimizer', choices=OPTIMIZERS_NAMES, default=OPTIMIZERS_NAMES[0], help="Optimiser to use") diff --git a/src/scripts/logger.py b/src/scripts/logger.py index 8bb4d97..3d10c85 100644 --- a/src/scripts/logger.py +++ b/src/scripts/logger.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Dict, Any, Union, List +from typing import Optional, Dict, Any, Union, List, Tuple import numpy as np import torch @@ -8,6 +8,9 @@ from PIL import Image as pillow +from graphics.distributions import bivariate_pmf_heatmap, bivariate_pdf_heatmap +from pcs.models import PC + class Logger: def __init__( @@ -30,6 +33,9 @@ def __init__( wandb_kwargs = dict() self._setup_wandb(wandb_path, **wandb_kwargs) + self._logged_distributions = list() + self._logged_wcoords = list() + @property def has_graphical_endpoint(self) -> bool: return self._tboard_writer is not None or wandb.run @@ -99,7 +105,25 @@ def log_hparams( if wandb.run: wandb.run.summary.update(metric_dict) + def log_distribution( + self, + model: PC, + discretized: bool, + lim: Tuple[Tuple[Union[float, int], Union[float, int]], Tuple[Union[float, int], Union[float, int]]], + device: Optional[Union[str, torch.device]] = None + ): + xlim, ylim = lim + if discretized: + dist_hmap = bivariate_pmf_heatmap(model, xlim, ylim, device=device) + else: + dist_hmap = bivariate_pdf_heatmap(model, xlim, ylim, device=device) + self._logged_distributions.append(dist_hmap.astype(np.float32, copy=False)) + def close(self): + if self._logged_distributions: + self.save_array(np.stack(self._logged_distributions, axis=0), 'distribution.npy') + if self._logged_wcoords: + self.save_array(np.stack(self._logged_wcoords, axis=0), 'wcoords.npy') if self._tboard_writer is not None: self._tboard_writer.close() if wandb.run: @@ -116,3 +140,12 @@ def save_image(self, data: np.ndarray, filepath: str): def save_array(self, array: np.ndarray, filepath: str): if self.checkpoint_path: np.save(os.path.join(self.checkpoint_path, filepath), array) + + def load_array(self, filepath: str) -> Optional[np.ndarray]: + if self.checkpoint_path: + try: + array = np.load(os.path.join(self.checkpoint_path, filepath)) + except OSError: + return None + return array + return None diff --git a/src/scripts/plots/bsplines.py b/src/scripts/plots/bsplines.py index 02f6370..9b5fddd 100644 --- a/src/scripts/plots/bsplines.py +++ b/src/scripts/plots/bsplines.py @@ -30,12 +30,13 @@ fig, ax = plt.subplots(num_rows, num_cols) for i in range(y.shape[1]): - m = y[:, i] > 3e-4 + m = y[:, i] > 7.5e-4 ax.plot(x[m], y[m, i], linewidth=2, alpha=0.6) - w = 1.1 + torch.rand(y.shape[1], dtype=torch.float64) + #w = 1.1 + torch.rand(y.shape[1], dtype=torch.float64) + w = -0.2 + torch.randn(y.shape[1], dtype=torch.float64) z = torch.sum(y * w.unsqueeze(dim=0), dim=1) - ax.plot(x, z, c='k', linewidth=2) + ax.plot(x, z, c='k', alpha=0.7, linewidth=1.5) ax.set_xticks(np.unique(knots.astype(np.float32))) ax.margins(0.1) diff --git a/src/scripts/plots/gpt2dist/lines.py b/src/scripts/plots/gpt2dist/lines.py index 4efa8be..0d3f628 100644 --- a/src/scripts/plots/gpt2dist/lines.py +++ b/src/scripts/plots/gpt2dist/lines.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +from scipy import stats from matplotlib import pyplot as plt from matplotlib import rcParams @@ -13,13 +14,18 @@ parser = argparse.ArgumentParser( description="Plot metrics by number of parameters line graphs for the experiments on GPT2 distillation", ) -parser.add_argument('tboard_path', default='tboard-runs', type=str, help="The Tensorboard runs path") +parser.add_argument('tboard_path', default='tboard-runs/gpt2-commongen-grid', type=str, help="The Tensorboard runs path") parser.add_argument('--metric', default='avg_ll', help="The metric to plot") parser.add_argument('--legend', action='store_true', default=False) parser.add_argument('--train', action='store_true', default=False) parser.add_argument('--title', action='store_true', default=False) parser.add_argument('--median', action='store_true', default=False, help="Whether to plt min, median and max areas") +""" +python -m scripts.plots.gpt2dist.lines tboard-runs/gpt2-commongen-ihmm --median --title --train --legend & \ +python -m scripts.plots.gpt2dist.lines tboard-runs/gpt2-commongen-ihmm --median --title & +""" + def format_model(m: str, exp_reparam: bool = False) -> str: if m == 'MonotonicPC': @@ -61,7 +67,7 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame: if __name__ == '__main__': args = parser.parse_args() metric = 'Best/Train/' + args.metric if args.train else 'Best/Test/' + args.metric - df = retrieve_tboard_runs(args.tboard_path, metric) + df = retrieve_tboard_runs(args.tboard_path, metric, ignore_diverged=False) num_rows = 1 num_cols = 1 @@ -70,23 +76,27 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame: markers = ['o', 'D'] num_points = 6 - for k, model_name in enumerate(['MonotonicPC', 'BornPC']): + metrics = defaultdict(dict) + model_names = ['MonotonicPC', 'BornPC'] + for k, model_name in enumerate(model_names): model_df = df[df['model'] == model_name] if model_name == 'BornPC': rows_to_keep = {'init_method': ['uniform', 'positive-skewed-normal', 'normal'], 'learning_rate': [5e-3, 1e-2, 5e-2]} + model_df = model_df[model_df['init_scale'] == 0.1] else: rows_to_keep = {'init_method': ['uniform', 'dirichlet', 'log-normal'], 'learning_rate': [5e-3, 1e-2, 5e-2]} + model_df = model_df[model_df['init_scale'] == 1.0] if rows_to_keep is not None: for r, vs in rows_to_keep.items(): model_df = model_df[model_df[r].isin(vs)] + model_df.to_csv(f'{model_name}-gpt2commongen-results.csv', index=None) group_model_df = model_df.groupby(by=['init_method', 'learning_rate']) should_label = True - metrics = defaultdict(list) + metrics[model_name] = defaultdict(list) for j, hparam_df in group_model_df: ms, ps = hparam_df[metric].tolist(), hparam_df['num_components'].tolist() if len(np.unique(ms)) < num_points or len(np.unique(ps)) < num_points: continue - #hparam_df.to_csv(f'{model_name}-{j}.csv') ms = np.array(ms, dtype=np.float64) ps = np.array(ps, dtype=np.int64) sort_indices = np.argsort(ps) @@ -95,22 +105,45 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame: ps = ps[:num_points] ms = ms[:num_points] for p, m in zip(ps.tolist(), ms.tolist()): - metrics[p].append(m) + metrics[model_name][p].append(m) if not args.median: should_label = False label = f'{format_model(model_name)}' if should_label else None ax.plot(ps, ms, label=label, marker=markers[k], markersize=3, linewidth=1.5, linestyle='-', alpha=0.2, c=f'C{k}') if args.median: label = f'{format_model(model_name)}' if should_label else None - median_metrics = sorted(list(map(lambda x: (x[0], np.median(x[1])), metrics.items())), key=lambda x: x[0]) - bot_metrics = sorted(list(map(lambda x: (x[0], np.quantile(x[1], q=0.1)), metrics.items())), key=lambda x: x[0]) - top_metrics = sorted(list(map(lambda x: (x[0], np.quantile(x[1], q=0.8)), metrics.items())), key=lambda x: x[0]) + median_metrics = sorted(list(map( + lambda x: (x[0], np.median(x[1])), metrics[model_name].items())), + key=lambda x: x[0]) + bot_metrics = sorted(list(map( + lambda x: (x[0], np.quantile(x[1], q=0.05)), metrics[model_name].items())), + key=lambda x: x[0]) + top_metrics = sorted(list(map( + lambda x: (x[0], np.quantile(x[1], q=0.95)), metrics[model_name].items())), + key=lambda x: x[0]) median_ps, median_ms = zip(*median_metrics) _, bot_ms = zip(*bot_metrics) _, top_ms = zip(*top_metrics) ax.plot(median_ps, median_ms, label=label, marker=markers[k], markersize=3, linewidth=1.5, linestyle='-', alpha=0.8, c=f'C{k}') ax.fill_between(median_ps, bot_ms, top_ms, alpha=0.2, color=f'C{k}') + assert len(model_names) == 2 + model_a, model_b = model_names + spvalues = defaultdict(lambda: defaultdict(dict)) + for ts in ['mannwithneyu', 'ttest']: + for al in ['greater']: + for k in sorted(metrics[model_a].keys() & metrics[model_b].keys()): + lls_a = metrics[model_a][k] + lls_b = metrics[model_b][k] + if ts == 'mannwithneyu': + s, p = stats.mannwhitneyu(lls_b, lls_a, method='exact', alternative=al) + elif ts == 'ttest': + s, p = stats.ttest_ind(lls_b, lls_a, alternative=al) + else: + assert False, "Should not happen :(" + spvalues[ts][al][k] = (round(s, 3), round(p, 4)) + print(spvalues) + #if args.train: # gpt2_average_log_likelihood = -52.01387770076976 # ax.axhline(y=gpt2_average_log_likelihood, color='k', linewidth=2, linestyle='dotted', label='GPT2') @@ -127,9 +160,11 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame: ha='right', va='bottom', xycoords='axes fraction', textcoords='offset points') ax.set_xscale('log') if args.train: - ax.set_yticks([-80.0, -75.0, -70.0, -65.0]) + #ax.set_yticks([-80.0, -75.0, -70.0, -65.0]) + ax.set_yticks([-80.0, -75.0, -70.0]) else: - ax.set_yticks([-90.0, -85.0, -80.0, -75.0]) + #ax.set_yticks([-90.0, -85.0, -80.0, -75.0]) + ax.set_yticks([-85.0, -80.0, -75.0]) if args.legend: ax.legend() if args.title: diff --git a/src/scripts/plots/ring/distgif.py b/src/scripts/plots/ring/distgif.py new file mode 100644 index 0000000..98b0c97 --- /dev/null +++ b/src/scripts/plots/ring/distgif.py @@ -0,0 +1,95 @@ +import os +import argparse +from typing import Callable + +import matplotlib.cm as cm +import numpy as np +from PIL import Image as pillow +import cv2 + +parser = argparse.ArgumentParser( + description="Create GIF from distributions" +) +parser.add_argument('path', type=str, default="checkpoints/gaussian-ring") +parser.add_argument('--max-num-frames', type=int, default=150, help="The maximum number of frames") +parser.add_argument('--copy-frames', type=int, default=30, help="The number of frames to copy at the end") +parser.add_argument('--gif-size', type=int, default=256, help="The width and height of the gifs") +parser.add_argument('--duration', type=int, default=100, help="The duration of each frame in ms") +parser.add_argument('--drop-last-frames', type=int, default=0, help="The number of last frames to drop") + +""" +python -m scripts.plots.ring.distgif checkpoints/loss-landscape --drop-last-frames 164 +""" + + +if __name__ == '__main__': + def to_rgb(x: np.ndarray, cmap: cm.ScalarMappable, cmap_transform: Callable[[np.ndarray], np.ndarray]) -> np.ndarray: + #x = x[51:-50, 51:-50] + x = (cmap.to_rgba(cmap_transform(x)) * 255.0).astype(np.uint8)[..., :-1] + if x.shape[0] != args.gif_size or x.shape[1] != args.gif_size: + x = cv2.resize(x, dsize=(args.gif_size, args.gif_size), interpolation=cv2.INTER_CUBIC) + return x + + def to_rgb_image(x: np.ndarray) -> pillow.Image: + return pillow.fromarray(x, mode='RGB') + + print("Loading the GIF data ...") + + args = parser.parse_args() + checkpoint_paths = [ + f"{args.path}/ring/MonotonicPC/RGran_R1_K2_D1_Lcp_OAdam_LR0.005_BS64_IU", + f"{args.path}/ring/MonotonicPC/RGran_R1_K16_D1_Lcp_OAdam_LR0.005_BS64_IU", + f"{args.path}/ring/BornPC/RGran_R1_K2_D1_Lcp_OAdam_LR0.001_BS64_IN" + ] + labels = [ + 'Ground Truth', + 'GMM-2', + 'GMM-16', + 'NGMM-2' + ] + gt_array = np.load(os.path.join(checkpoint_paths[0], 'gt.npy')) + gt_array = np.broadcast_to(gt_array, (args.max_num_frames, gt_array.shape[0], gt_array.shape[1])) + arrays = map(lambda p: np.load(os.path.join(p, 'distribution.npy')), checkpoint_paths) + if args.drop_last_frames > 0: + arrays = map(lambda a: a[:-args.drop_last_frames], arrays) + arrays = [gt_array] + list(arrays) + + print("Constructing the GIF ...") + + num_frames = min(args.max_num_frames, min(len(a) for a in arrays)) + frames_idx = [np.linspace(0.0, 1.0, num=num_frames + 1, endpoint=True)[:-1] for _ in range(len(arrays))] + arrays_idx = list(map(lambda x: np.floor(x[0] * len(x[1])).astype(np.int64), zip(frames_idx, arrays))) + arrays = list(map(lambda x: x[1][x[0]], zip(arrays_idx, arrays))) + + cmap_transform = lambda x: np.tanh(2.0 + 0.8 * np.log(x)) + cmap_min = cmap_transform(min(np.min(a) for a in arrays)) + cmap_max = cmap_transform(max(np.max(a) for a in arrays)) + cmap = cm.ScalarMappable(cm.colors.Normalize(cmap_min, cmap_max), cmap='turbo') + arrays = map( + lambda a: np.array([to_rgb(a[i], cmap, cmap_transform) for i in range(num_frames)]), + arrays + ) + + caption_height = 48 + font, fontscale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 1, 2 + arrays = map( + lambda x: np.concatenate([ + x[1], np.tile(cv2.putText( + np.full(fill_value=255, shape=(caption_height, x[1].shape[2], 3), dtype=np.uint8), + x[0], + (int(0.5 * (x[1].shape[2] - cv2.getTextSize(x[0], font, fontscale, thickness)[0][0])), + int(0.5 * (caption_height + cv2.getTextSize(x[0], font, fontscale, thickness)[0][1]))), + font, fontscale, (16, 16, 16), thickness, cv2.LINE_AA), reps=(num_frames, 1, 1, 1)) + ], axis=1), zip(labels, arrays) + ) + gif_images = np.concatenate(list(arrays), axis=2) + + print("Saving GIF to file ...") + + gif_iterator = ( + (to_rgb_image(gif_images[i]) if i < len(gif_images) else to_rgb_image(gif_images[-1])) + for i in range(len(gif_images) + args.copy_frames) + ) + img = next(gif_iterator) + with open(os.path.join('figures', 'gaussian-ring', f'learning-distributions.gif'), 'wb') as fp: + img.save(fp=fp, format='GIF', append_images=gif_iterator, save_all=True, duration=args.duration, loop=0, optimize=True) diff --git a/src/scripts/plots/ring/ellipses.py b/src/scripts/plots/ring/ellipses.py new file mode 100644 index 0000000..417cc84 --- /dev/null +++ b/src/scripts/plots/ring/ellipses.py @@ -0,0 +1,214 @@ +import argparse +import os.path +from typing import Optional + +import matplotlib as mpl +import numpy as np +from scipy import special +import torch +from matplotlib import pyplot as plt +from sklearn.preprocessing import StandardScaler + +from datasets.loaders import load_artificial_dataset +from graphics.distributions import kde_samples_hmap +from graphics.utils import setup_tueplots +from pcs.models import TensorizedPC, PC, MonotonicPC +from scripts.utils import setup_model, setup_data_loaders + +parser = argparse.ArgumentParser( + description="PDFs plotter" +) +parser.add_argument('--checkpoint-path', default='checkpoints', type=str, help="The checkpoints path") +parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title") + + +def ring_kde() -> np.ndarray: + splits = load_artificial_dataset('ring', num_samples=50000, dtype=np.dtype(np.float64)) + data = np.concatenate(splits, axis=0) + scaler = StandardScaler() + data = scaler.fit_transform(data) + data_min, data_max = np.min(data, axis=0), np.max(data, axis=0) + drange = np.abs(data_max - data_min) + data_min, data_max = (data_min - drange * 0.05), (data_max + drange * 0.05) + xlim, ylim = [(data_min[i], data_max[i]) for i in range(len(data_min))] + return kde_samples_hmap(data, xlim=xlim, ylim=ylim, bandwidth=0.16) + + +def format_model_name(m: str, num_components: int) -> str: + if m == 'MonotonicPC': + return f"GMM ($K \! = \! {num_components}$)" + elif m == 'BornPC': + return f"NGMM ($K \! = \! {num_components}$)" + return m + + +def load_mixture( + model_name: str, + exp_id_fmt: str, + num_components: int, + learning_rate: float = 5e-3, + batch_size: int = 64 +) -> TensorizedPC: + metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) + model: TensorizedPC = setup_model(model_name, metadata, num_components=num_components) + exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) + filepath = os.path.join(args.checkpoint_path, 'gaussian-ring', 'ring', model_name, exp_id, 'model.pt') + state_dict = torch.load(filepath, map_location='cpu') + model.load_state_dict(state_dict['weights']) + return model + + +def load_pdf( + model: str, + exp_id_fmt: str, + num_components, + learning_rate: float = 5e-3, + batch_size: int = 64 +) -> np.ndarray: + exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) + filepath = os.path.join(args.checkpoint_path, 'gaussian-ring', 'ring', model, exp_id, 'pdf.npy') + return np.load(filepath) + + +def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes): + mus = mixture.input_layer.mu[0, :, 0, :].detach().numpy() + covs = np.exp(2 * mixture.input_layer.log_sigma[0, :, 0, :].detach().numpy()) + num_components = mus.shape[-1] + mix_weights = mixture.layers[-1].weight[0, 0].detach().numpy() + if isinstance(mixture, MonotonicPC): + mix_weights = special.softmax(mix_weights) + mix_weights = mix_weights / np.max(mix_weights) + else: + # assert num_components == 2 + # mix_weights = np.array([mix_weights[0] ** 2, mix_weights[1] ** 2, 2.0 * mix_weights[0] * mix_weights[1]]) + # # Products of Gaussian pdfs + # new_covs = np.array([ + # covs[:, 0] / 2.0, + # covs[:, 1] / 2.0, + # (covs[:, 0] * covs[:, 1]) / (covs[:, 0] + covs[:, 1]) + # ]).T # New covariances + # new_mus = np.array([ + # mus[:, 0], + # mus[:, 1], + # new_covs[:, 2] * (mus[:, 0] / covs[:, 0] + mus[:, 1] / covs[:, 1]) + # ]).T # New means + # num_components = 3 + # mus = new_mus + # covs = new_covs + mix_weights = -mix_weights / np.max(np.abs(mix_weights)) + for i in range(num_components): + mu = mus[:, i] + cov = np.diag(covs[:, i]) + v, w = np.linalg.eigh(cov) + v = 2.0 * np.sqrt(2.0) * np.sqrt(v) + ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.8, fill=False) + ell_dot = mpl.patches.Circle(mu, radius=0.03, fill=True) + if isinstance(mixture, MonotonicPC): + ell.set_alpha(mix_weights[i]) + ell.set_color('red') + ell_dot.set_alpha(0.5 * mix_weights[i]) + ell_dot.set_color('red') + else: + if mix_weights[i] <= 0.0: + ell.set_alpha(min(1.0, 3 * np.abs(mix_weights[i]))) + #ell.set_color('#E53935') + ell.set_linestyle('dotted') + ell.set_color('red') + ell_dot.set_alpha(0.5 * np.abs(mix_weights[i])) + #ell_dot.set_color('#E53935') + ell_dot.set_color('red') + else: + ell.set_alpha(mix_weights[i]) + ell.set_color('red') + ell_dot.set_alpha(0.5 * mix_weights[i]) + ell_dot.set_color('red') + ax.add_artist(ell) + ax.add_artist(ell_dot) + + +def plot_pdf( + pdf: np.ndarray, + metadata: dict, + ax: plt.Axes, vmin: + Optional[float] = None, + vmax: Optional[float] = None +): + x_lim = metadata['domains'][0] + y_lim = metadata['domains'][1] + x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) + y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) + + x_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + y_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + + xi, yi = np.mgrid[range(pdf.shape[0]), range(pdf.shape[1])] + xi = (xi + 0.5) / pdf.shape[0] + yi = (yi + 0.5) / pdf.shape[1] + xi = xi * (x_lim[1] - x_lim[0]) + x_lim[0] + yi = yi * (y_lim[1] - y_lim[0]) + y_lim[0] + ax.pcolormesh(xi, yi, pdf, vmin=vmin, vmax=vmax) + + +if __name__ == '__main__': + args = parser.parse_args() + + models = [ + 'MonotonicPC', + 'MonotonicPC', + 'BornPC', + 'MAF', + 'NSF' + ] + + num_components = [2, 16, 2, 128, 128] + learning_rates = [5e-3, 5e-3, 4e-3, 1e-3, 1e-3] + + exp_id_formats = [ + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN', + 'K{}_OAdam_LR{}_BS{}', + 'K{}_OAdam_LR{}_BS{}' + ] + + mixtures = [ + load_mixture(m, eif, nc, lr) + for m, eif, nc, lr in zip(models[:3], exp_id_formats, num_components, learning_rates) + ] + [None, None] + + pdfs = [ + load_pdf(m, eif, nc, lr) + for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates) + ] + vmax = np.max(pdfs) + vmin = 0.0 + + metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) + + os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True) + for idx, (p, pdf, m, nc) in enumerate(zip(mixtures, pdfs, models, num_components)): + setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0) + fig, ax = plt.subplots(1, 1) + title = f"{format_model_name(m, nc)}" if args.title else None + + plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax) + if p is not None: + plot_mixture_ellipses(p, ax=ax) + + x_lim = metadata['domains'][0] + y_lim = metadata['domains'][1] + x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) + y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) + x_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + y_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + + ax.set_xlim(*x_lim) + ax.set_ylim(*y_lim) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect(1.0) + + if args.title: + ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center') + + plt.savefig(os.path.join('figures', 'gaussian-ring', f'pdfs-ellipses-{idx}.png'), dpi=1200) diff --git a/src/scripts/plots/ring/pdfs.py b/src/scripts/plots/ring/pdfs.py index a575c1a..f7e694d 100644 --- a/src/scripts/plots/ring/pdfs.py +++ b/src/scripts/plots/ring/pdfs.py @@ -14,6 +14,7 @@ description="PDFs plotter" ) parser.add_argument('--checkpoint-path', default='checkpoints', type=str, help="The checkpoints path") +parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title") def ring_kde() -> np.ndarray: @@ -41,14 +42,9 @@ def load_pdf(model: str, exp_id: str) -> np.ndarray: return np.load(filepath) -def plot_pdf(pdf: np.ndarray, ax: plt.Axes, vmin: Optional[float] = None, vmax: Optional[float] = None, title: Optional[str] = None, color: str = 'k'): +def plot_pdf(pdf: np.ndarray, ax: plt.Axes, vmin: Optional[float] = None, vmax: Optional[float] = None): xi, yi = np.mgrid[range(pdf.shape[0]), range(pdf.shape[1])] ax.pcolormesh(xi, yi, pdf, vmin=vmin, vmax=vmax) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_aspect(1.0) - if title is not None: - ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center', color=color) if __name__ == '__main__': @@ -74,7 +70,7 @@ def plot_pdf(pdf: np.ndarray, ax: plt.Axes, vmin: Optional[float] = None, vmax: for m, eid in zip(models[1:], exp_ids[1:]) ] pdfs = [truth_pdf] + pdfs - vmax = np.max(pdfs) + vmax = np.max(truth_pdf) vmin = 0.0 os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True) @@ -87,9 +83,22 @@ def plot_pdf(pdf: np.ndarray, ax: plt.Axes, vmin: Optional[float] = None, vmax: else: num_components = int(eid.split('_')[0][1:]) title = f"{format_model_name(m, num_components)}" - color = 'k' if idx == len(models) - 1 else 'k' else: title = m - color = 'k' - plot_pdf(p, vmin=vmin, vmax=vmax, ax=ax, title=title, color=color) - plt.savefig(os.path.join('figures', 'gaussian-ring', f'pdfs-{idx}.png'), dpi=1200) + + if idx == 0: + vmax = None + args.title = True + + plot_pdf(p, vmin=vmin, vmax=vmax, ax=ax) + + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect(1.0) + if args.title: + ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center') + + if idx == 0: + plt.savefig(os.path.join('figures', 'gaussian-ring', f'pdfs-gt.png'), dpi=1200) + else: + plt.savefig(os.path.join('figures', 'gaussian-ring', f'pdfs-{idx}.png'), dpi=1200) diff --git a/src/scripts/tables/__init__.py b/src/scripts/tables/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/scripts/tables/uci/__init__.py b/src/scripts/tables/uci/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/scripts/utils.py b/src/scripts/utils.py index e03e693..a8deaf9 100644 --- a/src/scripts/utils.py +++ b/src/scripts/utils.py @@ -25,7 +25,7 @@ from pcs.layers.tucker import MonotonicTucker2Layer, BornTucker2Layer from pcs.layers.candecomp import MonotonicCPLayer, BornCPLayer from pcs.models import PC, MonotonicPC, BornPC -from graphics.distributions import plot_bivariate_samples_hmap, plot_bivariate_discrete_samples_hmap +from graphics.distributions import plot_bivariate_samples_hmap, plot_bivariate_discrete_samples_hmap, kde_samples_hmap from region_graph import RegionGraph, RegionNode from region_graph.linear_vtree import LinearVTree from region_graph.quad_tree import QuadTree @@ -164,22 +164,17 @@ def retrieve_tboard_runs(tboard_path: str, metrics: Union[str, List[str]], ignor # Throw out rows with no result for the metric for m in metrics: df_scalars = df_scalars[~pd.isna(df_scalars[m])] - - try: - n_diverged = np.sum(df_scalars['diverged']) - assert n_diverged == 0, "Some runs have diverged" - assert len(df_hparams) == len(df_scalars), "Number of runs and results is different" - except AssertionError as e: - if ignore_diverged: - print(f"Found {n_diverged} diverged runs. Ignoring...") - df_scalars = df_scalars[df_scalars['diverged'] == False] - else: - raise e + assert len(df_hparams) == len(df_scalars), "Number of runs and results is different" + if ignore_diverged: + n_diverged = int(np.sum(df_scalars['diverged'])) + print(f"Found {n_diverged} diverged runs. Ignoring...") + df_scalars = df_scalars[df_scalars['diverged'] == False] df = df_hparams.merge(df_scalars, on='dir_name', sort=True).drop('dir_name', axis=1) return df + def retrieve_tboard_df(tboard_path: str) -> pd.DataFrame: reader = SummaryReader(tboard_path, pivot=True, extra_columns={'dir_name'}) df_hparams = reader.hparams @@ -383,8 +378,7 @@ def setup_data_loaders( metadata['type'] = 'artificial' metadata['interval'] = (np.min(data_min), np.max(data_max)) metadata['domains'] = [(data_min[i], data_max[i]) for i in range(len(data_min))] - metadata['hmap'] = plot_bivariate_samples_hmap( - train_data, xlim=metadata['domains'][0], ylim=metadata['domains'][1]) + metadata['hmap'] = kde_samples_hmap(train_data, xlim=metadata['domains'][0], ylim=metadata['domains'][1]) train_dataloader = DataLoader(train_data, batch_size, num_workers=num_workers, shuffle=True) valid_dataloader = DataLoader(valid_data, batch_size, num_workers=num_workers) test_dataloader = DataLoader(test_data, batch_size, num_workers=num_workers) @@ -409,6 +403,7 @@ def setup_model( init_method: str = 'normal', init_scale: float = 1.0, dequantize: bool = False, + l2norm: bool = False, seed: int = 123 ) -> Union[PC, Flow]: if binomials and splines: @@ -483,13 +478,15 @@ def setup_model( in_mixture_layer_cls = BornMixtureLayer elif 'HMM' in model_name: model_cls = MonotonicHMM if 'Monotonic' in model_name else BornHMM + kwargs = dict() if 'Monotonic' in model_name else {'l2norm': l2norm} assert dataset_type == 'language' model = model_cls( vocab_size=interval[1] + 1, seq_length=num_variables, hidden_size=num_components, init_method=init_method, - init_scale=init_scale + init_scale=init_scale, + **kwargs ) return model elif model_name == 'NICE': @@ -563,6 +560,8 @@ def setup_model( if model_name == 'BornPC': if all(n not in input_layer_cls.__name__ for n in ['Normal', 'Binomial']): input_layer_kwargs['exp_reparam'] = exp_reparam + if 'Embeddings' in input_layer_cls.__name__: + input_layer_kwargs['l2norm'] = l2norm compute_layer_kwargs['exp_reparam'] = exp_reparam return model_cls( rg, diff --git a/src/tests/test_likelihood.py b/src/tests/test_likelihood.py index 5b3d1d1..c333ec4 100644 --- a/src/tests/test_likelihood.py +++ b/src/tests/test_likelihood.py @@ -22,6 +22,14 @@ from tests.test_utils import generate_all_binary_samples, generate_all_ternary_samples +def check_normalized_log_scores(model: PC, x: torch.Tensor) -> torch.Tensor: + scores = model.log_score(x) + assert scores.shape == (len(x), 1) + assert torch.all(torch.isfinite(scores)) + assert torch.allclose(torch.logsumexp(scores, dim=0).exp(), torch.tensor(1.0), atol=1e-15) + return scores + + def check_evi_ll(model: PC, x: torch.Tensor) -> torch.Tensor: lls = model.log_prob(x) assert lls.shape == (len(x), 1) @@ -95,6 +103,23 @@ def test_born_pc_random(compute_layer, num_variables, num_replicas, depth, num_c check_mar_ll_one(model, data, num_mar_variables=num_mar_variables) +@pytest.mark.parametrize("compute_layer,num_variables,num_replicas,num_components", + list(itertools.product( + [BornCPLayer], + [3, 8], [1, 4], [1, 2] + ))) +def test_born_pc_linear_stiefel(compute_layer, num_variables, num_replicas, num_components): + rg = LinearVTree(num_variables, num_repetitions=num_replicas) + compute_layer_kwargs = {'init_method': 'stiefel'} + input_layer_kwargs = {'init_method': 'stiefel', 'num_states': 3} + model = BornPC( + rg, input_layer_cls=BornEmbeddings, compute_layer_cls=compute_layer, + input_layer_kwargs=input_layer_kwargs, compute_layer_kwargs=compute_layer_kwargs, + num_components=num_components) + data = torch.LongTensor(generate_all_ternary_samples(num_variables)) + check_normalized_log_scores(model, data) + + @pytest.mark.parametrize("compute_layer,image_shape,num_components,input_mixture", list(itertools.product( [MonotonicCPLayer], @@ -139,17 +164,17 @@ def test_monotonic_pc_pseudo_large_image(compute_layer, image_shape, num_compone assert lls.shape == (len(data), 1) -@pytest.mark.parametrize("compute_layer,image_shape,num_components,input_mixture", +@pytest.mark.parametrize("compute_layer,image_shape,num_components,input_mixture,l2norm", list(itertools.product( [BornCPLayer], - [(1, 7, 7), (3, 28, 28)], [1, 3], [False, True] + [(1, 7, 7), (3, 28, 28)], [1, 3], [False, True], [False, True] ))) -def test_born_pc_pseudo_large_image(compute_layer, image_shape, num_components, input_mixture): +def test_born_pc_pseudo_large_image(compute_layer, image_shape, num_components, input_mixture, l2norm): rg = QuadTree(image_shape, struct_decomp=True) model = BornPC( rg, input_layer_cls=BornEmbeddings, compute_layer_cls=compute_layer, input_mixture=input_mixture, num_components=num_components, - input_layer_kwargs={'num_states': 768}) + input_layer_kwargs={'num_states': 768, 'l2norm': l2norm}) data = torch.round(torch.rand((42, np.prod(image_shape)))).long() lls = model.log_prob(data) assert lls.shape == (len(data), 1) @@ -339,9 +364,9 @@ def test_monotonic_hmm(seq_length, hidden_size): check_evi_ll(model, data) -@pytest.mark.parametrize("seq_length,hidden_size", - list(itertools.product([2, 7], [1, 13]))) -def test_born_hmm(seq_length, hidden_size): - model = BornHMM(vocab_size=3, seq_length=seq_length, hidden_size=hidden_size) +@pytest.mark.parametrize("seq_length,hidden_size,l2norm", + list(itertools.product([2, 7], [1, 13], [False, True]))) +def test_born_hmm(seq_length, hidden_size, l2norm): + model = BornHMM(vocab_size=3, seq_length=seq_length, hidden_size=hidden_size, l2norm=l2norm) data = torch.LongTensor(generate_all_ternary_samples(seq_length)) check_evi_ll(model, data)