
Commit

made repo up to date
loreloc committed Feb 27, 2024
1 parent 028fa8d commit 68d2e09
Showing 22 changed files with 1,031 additions and 122 deletions.
2 changes: 1 addition & 1 deletion src/datasets/wrappers/gpt2_commongen.py
@@ -10,7 +10,7 @@
def load_gpt2_commongen(
path: str = 'datasets',
seed: int = 42,
num_splits: int = 20
num_splits: int = 40
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
splits = glob.glob(os.path.join(path, 'gpt2_commongen', 'common-gen.train.*'))
assert len(splits) > 0, f"There are no GPT2 CommonGen-tuned generated splits: {path} {splits}"
47 changes: 36 additions & 11 deletions src/graphics/distributions.py
@@ -2,6 +2,7 @@

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.neighbors import KernelDensity
from matplotlib import pyplot as plt

@@ -16,7 +17,7 @@ def plot_bivariate_samples_hmap(
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
zm: float = 0.0,
nbins: int = 600
nbins: int = 256
) -> np.ndarray:
setup_tueplots(1, 1, hw_ratio=1.0)
if xlim is None:
@@ -86,7 +87,7 @@ def kde_samples_hmap(
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
zm: float = 0.0,
nbins: int = 600,
nbins: int = 256,
*,
bandwidth: float = 0.2
) -> np.ndarray:
@@ -113,23 +114,47 @@ def bivariate_pdf_heatmap(
xlim: Tuple[float, float],
ylim: Tuple[float, float],
zm: float = 0.0,
nbins: int = 600,
device: Optional[Union[str, torch.device]] = None
nbins: int = 256,
batch_size: Optional[int] = None,
variables: Optional[Tuple[int, int]] = None,
device: Optional[Union[str, torch.device]] = None,
) -> np.ndarray:
if device is None:
device = 'cpu'
zm_xamount = np.abs(xlim[1] - xlim[0])
zm_yamount = np.abs(ylim[1] - ylim[0])
xlim = (xlim[0] - zm * zm_xamount), (xlim[1] + zm * zm_xamount)
ylim = (ylim[0] - zm * zm_yamount), (ylim[1] + zm * zm_yamount)
xi, yi = np.mgrid[xlim[0]:xlim[1]:nbins * 1j, ylim[0]:ylim[1]:nbins * 1j]
xy = np.stack([xi.flatten(), yi.flatten()], axis=1)\
xy = np.stack([xi.flatten(), yi.flatten()], axis=1) \
.astype(retrieve_default_dtype(numpy=True), copy=False)
if device is None:
device = 'cpu'
xy = torch.from_numpy(xy).to(device)
if isinstance(model, PC):
zi = model.log_prob(xy)
if model.num_variables < 2:
raise ValueError("A PC defined over at least two variables is required")
elif model.num_variables == 2:
xy = torch.from_numpy(xy).to(device)
if isinstance(model, PC):
zi = model.log_prob(xy)
else:
zi = model().log_prob(xy)
else:
zi = model().log_prob(xy)
if batch_size is None:
batch_size = 128
if variables is None or len(variables) != 2:
raise ValueError("Exactly two variable indices must be given to plot a bivariate marginal PDF")
# Perform variable marginalization
points = TensorDataset(torch.from_numpy(xy[:, 0]), torch.from_numpy(xy[:, 1]))
points = DataLoader(points, batch_size=batch_size, shuffle=False, drop_last=False)
mar_mask = torch.ones(1, model.num_variables, dtype=torch.bool, device=device)
mar_mask[:, variables] = False
zi = list()
for xf, yf in points:
samples = torch.zeros(size=(xf.shape[0], model.num_variables), dtype=xf.dtype)
samples[:, variables[0]] = xf
samples[:, variables[1]] = yf
samples = samples.to(device)
log_probs = model.log_marginal_prob(samples, mar_mask)
zi.append(log_probs)
zi = torch.concatenate(zi, dim=0)
zi = torch.exp(zi).cpu().numpy().reshape(xi.shape)
return zi

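Assuming the extended signature above, here is a minimal usage sketch (the model, variable indices, grid limits, and device below are hypothetical): for a PC over more than two variables, the function marginalizes out every variable except the two selected ones via `log_marginal_prob` and returns their bivariate marginal density on the grid.

```python
import torch

# Hypothetical: `model` is a PC over, say, 8 variables; we plot the marginal
# density of variables 2 and 5 on the square [-3, 3] x [-3, 3].
pdf = bivariate_pdf_heatmap(
    model,
    xlim=(-3.0, 3.0),
    ylim=(-3.0, 3.0),
    nbins=256,          # grid resolution per axis (the new default)
    batch_size=512,     # number of grid points evaluated per forward pass
    variables=(2, 5),   # the two variables to keep; all others are marginalized
    device='cuda',      # falls back to 'cpu' when None
)
# `pdf` is a (nbins, nbins) numpy array of marginal densities over the grid.
```
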
12 changes: 10 additions & 2 deletions src/graphics/utils.py
@@ -31,12 +31,20 @@ def setup_tueplots(
#})


def array_to_image(array: np.ndarray, vmin: Optional[float] = None, vmax: Optional[float] = None) -> np.ndarray:
def array_to_image(
array: np.ndarray,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
colorbar: bool = False
) -> np.ndarray:
assert len(array.shape) == 2
xi, yi = np.mgrid[range(array.shape[0]), range(array.shape[1])]
setup_tueplots(1, 1, hw_ratio=1.0)
fig, ax = plt.subplots()
ax.pcolormesh(xi, yi, array, shading='auto', vmin=vmin, vmax=vmax)
cmap = 'turbo' if colorbar else 'jet'
pcm = ax.pcolormesh(xi, yi, array, shading='auto', cmap=cmap, vmin=vmin, vmax=vmax)
if colorbar:
plt.colorbar(pcm)
ax.set_xticks([])
ax.set_yticks([])
return matplotlib_buffer_to_image(fig)
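For context, a small sketch of how the extended `array_to_image` might be used on such a density grid; the `pdf` array and variable names are hypothetical. Passing `colorbar=True` both draws the legend and switches the colormap from 'jet' to 'turbo'.

```python
import numpy as np

# Hypothetical 2-D density grid, e.g. the output of bivariate_pdf_heatmap above
pdf = np.random.rand(256, 256)

img_plain = array_to_image(pdf)                                          # 'jet' colormap, no colorbar
img_annotated = array_to_image(pdf, vmin=0.0, vmax=1.0, colorbar=True)   # 'turbo' colormap + colorbar
```
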
39 changes: 29 additions & 10 deletions src/pcs/hmm.py
@@ -3,6 +3,7 @@
import abc
from typing import Tuple, Union, Optional

import numpy as np
import torch
from torch import nn

@@ -95,12 +96,14 @@ def __init__(
seq_length: int,
hidden_size: int = 2,
init_method: str = 'normal',
init_scale: float = 1.0
init_scale: float = 1.0,
l2norm: bool = False
):
assert seq_length > 1
super().__init__(num_variables=seq_length)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.l2norm = l2norm

latent_prior = torch.empty(self.hidden_size)
init_params_(latent_prior, init_method, init_scale=init_scale)
@@ -119,63 +122,79 @@ def eval_log_pf(self) -> Tuple[Optional[Union[torch.Tensor, Tuple[torch.Tensor,
return None, log_pf

def _latent_prior(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
w = self.latent_prior
if self.l2norm:
w = w / torch.linalg.vector_norm(w, ord=2, dim=0, keepdim=True)

if square:
# x: (batch_size, hidden_size, hidden_size)
# self.latent_prior: (hidden_size,)
m_x, _ = torch.max(x, dim=2, keepdim=True) # (batch_size, hidden_size, 1)
x = x_si * torch.exp(x - m_x)
x = torch.sum(self.latent_prior * x, dim=2)
x = torch.einsum('bij,j->bi', x, w)
x_si = torch.sign(x.detach())
x = m_x.squeeze(dim=2) + safelog(torch.abs(x)) # (batch_size, hidden_size)
m_x, _ = torch.max(x, dim=1, keepdim=True) # (batch_size, 1)
x = x_si * torch.exp(x - m_x)
x = torch.sum(self.latent_prior * x, dim=1, keepdim=True)
x = torch.mm(x, w.unsqueeze(dim=1))
x_si = torch.sign(x.detach())
x = m_x + safelog(torch.abs(x)) # (batch_size, 1)
return x, x_si
# x: (batch_size, hidden_size)
# self.latent_prior: (hidden_size,)
m_x, _ = torch.max(x, dim=1, keepdim=True)
x = x_si * torch.exp(x - m_x)
y = torch.sum(self.latent_prior * x, dim=1, keepdim=True)
y = torch.mm(x, w.unsqueeze(dim=1))
y_si = torch.sign(y.detach())
y = safelog(torch.abs(y)) + m_x
return y, y_si

def _latent_conds(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
w = self.latent_conds
if self.l2norm:
w = w / torch.linalg.vector_norm(w, ord=2, dim=1, keepdim=True)

if square:
# x: (batch_size, hidden_size, hidden_size)
# self.latent_conds: (hidden_size, hidden_size)
m_x, _ = torch.max(x, dim=2, keepdim=True) # (batch_size, hidden_size, 1)
x = x_si * torch.exp(x - m_x)
x = torch.einsum('pi,bji->bpj', self.latent_conds, x)
x = torch.einsum('pi,bji->bpj', w, x)
x_si = torch.sign(x.detach())
x = m_x.permute(0, 2, 1) + safelog(torch.abs(x)) # (batch_size, hidden_size, hidden_size)
m_x, _ = torch.max(x, dim=2, keepdim=True) # (batch_size, hidden_size, 1)
x = x_si * torch.exp(x - m_x)
x = torch.einsum('qj,bpj->bpq', self.latent_conds, x)
x = torch.einsum('qj,bpj->bpq', w, x)
x_si = torch.sign(x.detach())
x = m_x + safelog(torch.abs(x)) # (batch_size, hidden_size, hidden_size)
return x, x_si
# x: (batch_size, hidden_size)
# self.latent_conds: (hidden_size, hidden_size)
m_x, _ = torch.max(x, dim=1, keepdim=True)
x = x_si * torch.exp(x - m_x)
y = torch.einsum('ij,bj->bi', self.latent_conds, x)
y = torch.einsum('ij,bj->bi', w, x)
y_si = torch.sign(y.detach())
y = safelog(torch.abs(y)) + m_x
return y, y_si

def _emission_conds(self, x: torch.Tensor, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
e = self.emission_conds
if self.l2norm:
e = e / torch.linalg.vector_norm(e, ord=2, dim=0, keepdim=True)

zi = torch.arange(self.hidden_size, device=x.device).unsqueeze(dim=0)
w = self.emission_conds[zi, x[:, i].unsqueeze(dim=-1)]
w = e[zi, x[:, i].unsqueeze(dim=-1)]
w_si = torch.sign(w.detach())
w = safelog(torch.abs(w))
return w, w_si

def _emission_conds_normalize(self) -> Tuple[torch.Tensor, torch.Tensor]:
w_si = torch.sign(self.emission_conds.detach())
w = safelog(torch.abs(self.emission_conds))
e = self.emission_conds
if self.l2norm:
e = e / torch.linalg.vector_norm(e, ord=2, dim=0, keepdim=True)

w_si = torch.sign(e.detach())
w = safelog(torch.abs(e))
m_w, _ = torch.max(w, dim=1, keepdim=True)
e_w = w_si * torch.exp(w - m_w)
z = torch.mm(e_w, e_w.T)
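A minimal sketch of the `l2norm` reparameterization introduced above, with hypothetical shapes: each parameter tensor is rescaled to unit L2 norm along the same dimension used in `_latent_prior`, `_latent_conds`, and `_emission_conds`, so the squared entries along that dimension sum to one. Presumably this is the point of the option in a squared, Born-style parameterization, though the source does not state it explicitly.

```python
import torch

hidden_size, vocab_size = 8, 50

# Hypothetical parameters shaped like the HMM ones above
latent_prior = torch.randn(hidden_size)                 # (hidden_size,)
latent_conds = torch.randn(hidden_size, hidden_size)    # (hidden_size, hidden_size)
emission_conds = torch.randn(hidden_size, vocab_size)   # (hidden_size, vocab_size)

# The same normalizations as in _latent_prior, _latent_conds and _emission_conds
w_prior = latent_prior / torch.linalg.vector_norm(latent_prior, ord=2, dim=0, keepdim=True)
w_conds = latent_conds / torch.linalg.vector_norm(latent_conds, ord=2, dim=1, keepdim=True)
w_emiss = emission_conds / torch.linalg.vector_norm(emission_conds, ord=2, dim=0, keepdim=True)

# Squared entries sum to one along the normalized dimension
assert torch.isclose(torch.sum(w_prior ** 2), torch.tensor(1.0))
assert torch.allclose(torch.sum(w_conds ** 2, dim=1), torch.ones(hidden_size))
assert torch.allclose(torch.sum(w_emiss ** 2, dim=0), torch.ones(vocab_size))
```
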
12 changes: 9 additions & 3 deletions src/pcs/initializers.py
@@ -16,7 +16,8 @@
'gamma',
'dirichlet',
'xavier-uniform',
'xavier-normal'
'xavier-normal',
'stiefel'
]


@@ -55,11 +56,16 @@ def init_params_(tensor: torch.Tensor, method: str = 'normal', init_loc: float =
fan_in, fan_out = tensor.shape[-1], tensor.shape[-2]
std = np.sqrt(2.0 / float(fan_in + fan_out))
a = np.sqrt(3.0) * std
return nn.init.uniform_(tensor, -a, a)
nn.init.uniform_(tensor, -a, a)
elif method == 'xavier-normal':
fan_in, fan_out = tensor.shape[-1], tensor.shape[-2]
std = np.sqrt(2.0 / float(fan_in + fan_out))
return nn.init.normal_(tensor, 0.0, std)
nn.init.normal_(tensor, 0.0, std)
elif method == 'stiefel':
assert tensor.shape[-2] <= tensor.shape[-1]
nn.init.normal_(tensor, init_loc, init_scale)
q, _ = torch.linalg.qr(tensor.transpose(-2, -1), mode='reduced')
tensor.copy_(q.transpose(-2, -1))
else:
raise NotImplementedError(f"Unknown initialization method called {method}")

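A small sketch of what the new 'stiefel' branch computes, on a hypothetical 4x16 tensor: a Gaussian matrix is orthogonalized with a reduced QR decomposition so that its rows become orthonormal (a point on the Stiefel manifold), which is why the branch asserts `shape[-2] <= shape[-1]`.

```python
import torch
from torch import nn

tensor = torch.empty(4, 16)        # rows <= columns, as the assert requires
nn.init.normal_(tensor, 0.0, 1.0)  # init_loc=0.0, init_scale=1.0
q, _ = torch.linalg.qr(tensor.transpose(-2, -1), mode='reduced')  # q: (16, 4), orthonormal columns
tensor.copy_(q.transpose(-2, -1))  # rows of `tensor` are now orthonormal

assert torch.allclose(tensor @ tensor.T, torch.eye(4), atol=1e-5)
```
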
4 changes: 2 additions & 2 deletions src/pcs/layers/compute.py
@@ -25,9 +25,9 @@ def __init__(

class MonotonicComputeLayer(ComputeLayer, abc.ABC):
def forward(self, x: torch.Tensor) -> torch.Tensor:
pass
raise NotImplementedError()


class BornComputeLayer(ComputeLayer, abc.ABC):
def forward(self, x: torch.Tensor, x_si: torch.Tensor, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
pass
raise NotImplementedError()
21 changes: 17 additions & 4 deletions src/pcs/layers/input.py
@@ -84,8 +84,10 @@ def __init__(
num_states: int = 2,
init_method: str = 'normal',
init_scale: float = 1.0,
exp_reparam: bool = False
exp_reparam: bool = False,
l2norm: bool = False
):
assert not exp_reparam or not l2norm, "At most one of --exp-reparam and --l2norm can be enabled"
super().__init__(rg_nodes, num_components)
self.num_states = num_states
weight = torch.empty(self.num_variables, self.num_replicas, self.num_components, num_states)
@@ -94,10 +96,16 @@ def __init__(
weight = torch.log(weight)
self.weight = nn.Parameter(weight, requires_grad=True)
self.exp_reparam = exp_reparam
self.l2norm = l2norm
self._ohe = num_states <= 256

def log_pf(self) -> Tuple[torch.Tensor, torch.Tensor]:
weight = torch.exp(self.weight) if self.exp_reparam else self.weight
if self.exp_reparam:
weight = torch.exp(self.weight)
elif self.l2norm:
weight = self.weight / torch.linalg.vector_norm(self.weight, ord=2, dim=2, keepdim=True)
else:
weight = self.weight

w_si = torch.sign(weight.detach()) # (num_variables, num_replicas, num_components, num_states)
w = safelog(torch.abs(weight)) # (num_variables, num_replicas, num_components, num_states)
@@ -109,7 +117,12 @@ def log_pf(self) -> Tuple[torch.Tensor, torch.Tensor]:
return z.unsqueeze(dim=0), z_si.unsqueeze(dim=0)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
weight = torch.exp(self.weight) if self.exp_reparam else self.weight
if self.exp_reparam:
weight = torch.exp(self.weight)
elif self.l2norm:
weight = self.weight / torch.linalg.vector_norm(self.weight, ord=2, dim=2, keepdim=True)
else:
weight = self.weight

# x: (-1, num_vars)
# self.weight: (num_vars, num_comps, num_states)
@@ -121,7 +134,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
w_si = torch.sign(w.detach())
w = safelog(torch.abs(w))
else:
weight = self.weight.permute(0, 3, 1, 2)
weight = weight.permute(0, 3, 1, 2)
w = weight[torch.arange(weight.shape[0], device=x.device), x]
w_si = torch.sign(w.detach())
w = safelog(torch.abs(w))
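For illustration, a sketch of the two reparameterizations of the embedding weights side by side, on a hypothetical tensor shaped (num_variables, num_replicas, num_components, num_states): `exp_reparam` makes every entry positive, while the new `l2norm` option rescales to unit L2 norm along the component dimension (dim=2), matching `log_pf` and `forward` above.

```python
import torch

# Hypothetical shape: (num_variables, num_replicas, num_components, num_states)
weight = torch.randn(6, 1, 4, 32)

w_exp = torch.exp(weight)  # exp_reparam: strictly positive entries
w_l2 = weight / torch.linalg.vector_norm(weight, ord=2, dim=2, keepdim=True)

assert torch.all(w_exp > 0)
assert torch.allclose(torch.linalg.vector_norm(w_l2, ord=2, dim=2), torch.ones(6, 1, 32))
```
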
1 change: 0 additions & 1 deletion src/pcs/models.py
@@ -443,7 +443,6 @@ def eval_log_pf(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:

def log_score(self, x: torch.Tensor) -> torch.Tensor:
(x, x_si), ldj = self._eval_input(x)

return self._eval_layers(x, x_si) + ldj

def log_marginal_score(
