
integrate KNN
Damowerko committed Mar 6, 2024
1 parent 07f9040 commit aa365bb
Showing 4 changed files with 56 additions and 30 deletions.
5 changes: 4 additions & 1 deletion scripts/main.py
@@ -2,7 +2,6 @@
 import os
 import typing
 from functools import partial
-from glob import glob
 from pathlib import Path
 from typing import List, Union
 
@@ -55,6 +54,8 @@ def main():
     group = parser.add_argument_group("Trainer")
     group.add_argument("--max_epochs", type=int, default=1000)
     group.add_argument("--patience", type=int, default=10)
+    group.add_argument("--profiler", type=str, default=None)
+    group.add_argument("--fast_dev_run", action="store_true")
 
     params = parser.parse_args()
     if params.operation == "train":
@@ -242,6 +243,8 @@ def make_trainer(params: argparse.Namespace, callbacks=[]) -> pl.Trainer:
         devices=1,
         max_epochs=params.max_epochs,
         default_root_dir=".",
+        profiler=params.profiler,
+        fast_dev_run=params.fast_dev_run,
     )
 
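Note: the two new flags are passed straight through to the Lightning Trainer. A minimal Python sketch of that path (assumed invocation; the script's other arguments and callbacks are omitted):

# Sketch only: how the new CLI flags reach pl.Trainer (everything else omitted).
# "--profiler simple" selects Lightning's built-in simple profiler; the default
# of None disables profiling. "--fast_dev_run" runs a single batch of
# train/val/test as a quick smoke test, without checkpointing.
import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument("--profiler", type=str, default=None)
parser.add_argument("--fast_dev_run", action="store_true")
params = parser.parse_args(["--profiler", "simple", "--fast_dev_run"])

trainer = pl.Trainer(
    max_epochs=1000,
    profiler=params.profiler,          # e.g. "simple" or "advanced"
    fast_dev_run=params.fast_dev_run,  # True -> one batch per stage
)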
41 changes: 23 additions & 18 deletions src/mtt/models/kernel.py
@@ -1,3 +1,5 @@
+import typing
+
 import pytorch_lightning as pl
 import torch
 import torchcps.kernel.nn as knn
@@ -16,33 +18,36 @@ def __init__(
         measurement_dim: int,
         state_dim: int,
         pos_dim: int,
-        n_weights: int = 32,
-        n_channels: int = 32,
-        n_layers: int = 2,
-        sigma: float = 10.0,
-        max_filter_kernels: int = 100,
-        update_positions: bool = False,
+        hidden_channels: int = 32,
+        n_layers: int = 4,
+        n_layers_mlp: int = 2,
+        hidden_channels_mlp: int = 128,
+        sigma: float | typing.Sequence[float] = 1.0,
+        max_filter_kernels: int = 32,
+        update_positions: bool = True,
+        alpha: float | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.save_hyperparameters()
         self.model = knn.KNN(
-            in_weights=measurement_dim,
-            out_weights=state_dim + 1,
-            n_weights=n_weights,
+            pos_dim,
+            in_channels=measurement_dim,
+            hidden_channels=hidden_channels,
+            out_channels=state_dim + 1,
             n_layers=n_layers,
-            n_channels=n_channels,
+            n_layers_mlp=n_layers_mlp,
+            hidden_channels_mlp=hidden_channels_mlp,
             sigma=sigma,
             max_filter_kernels=max_filter_kernels,
             update_positions=update_positions,
+            alpha=alpha,
         )
 
-    def forward(
-        self, x: torch.Tensor, x_pos: torch.Tensor, x_batch: torch.Tensor
-    ) -> SparseOutput:
-        x_mixture = knn.Mixture(x_pos, x)
-        y_mixture = self.model.forward(x_mixture, x_batch)
+    def _forward(self, x: torch.Tensor, x_pos: torch.Tensor, x_batch: torch.Tensor):
+        x_mixture = knn.Mixture(x_pos, x, x_batch)
+        y_mixture = self.model.forward(x_mixture)
         mu = y_mixture.positions
-        sigma = y_mixture.weights[:-1]
-        logp = y_mixture.weights[-1]
-        return SparseOutput(mu=mu, sigma=sigma, logp=logp)
+        sigma = y_mixture.weights[:, :-1]
+        logits = y_mixture.weights[:, -1]
+        return mu, sigma, logits
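For reference, the new `_forward` slices the KNN output weights into a covariance part and an existence logit. A tiny, self-contained Python illustration of that slicing convention (toy tensors, not the project's data):

# The KNN head emits state_dim + 1 weight channels per output kernel: the first
# state_dim channels are the pre-activation sigma, the last the existence logit.
import torch

state_dim = 4
weights = torch.randn(7, state_dim + 1)  # 7 predicted kernels
sigma_raw = weights[:, :-1]              # (7, state_dim) raw scales, softplus later
logits = weights[:, -1]                  # (7,) existence logits
assert sigma_raw.shape == (7, state_dim) and logits.shape == (7,)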
29 changes: 24 additions & 5 deletions src/mtt/models/sparse.py
@@ -5,6 +5,7 @@
 
 import pytorch_lightning as pl
 import torch
+import torch.nn.functional as F
 from scipy.optimize import linear_sum_assignment
 
 from mtt.data.sparse import SparseData
@@ -73,10 +74,30 @@ def to_stlabel(self, data: SparseData) -> SparseLabel:
         y_batch = data.target_batch_sizes
         return SparseLabel(y, y_batch)
 
-    @abstractmethod
     def forward(
         self, x: torch.Tensor, x_pos: torch.Tensor, x_batch: torch.Tensor
     ) -> SparseOutput:
+        # apply activation to the output of _forward
+        mu, sigma, logits = self._forward(x, x_pos, x_batch)
+        # Sigma must be > 0.0 for torch.distributions.Normal
+        sigma = F.softplus(sigma) + 1e-16
+        logp = F.logsigmoid(logits)
+        return SparseOutput(mu, sigma, logp)
+
+    @abstractmethod
+    def _forward(
+        self, x: torch.Tensor, x_pos: torch.Tensor, x_batch: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Partial implementation of the forward pass.
+        Should have no activation function on the output.
+        The output of this function is used in `self.forward`.
+
+        Returns:
+            mu: (N, d) Predicted states.
+            sigma: (N, d) Covariance of the states.
+            logits: (N,) Existence probabilities in logit space.
+        """
         raise NotImplementedError()
 
     def logp(
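A minimal Python sketch of the resulting contract (ToySparseModel is hypothetical and not part of this commit; in the repo the models in kernel.py and transformer.py play this role): subclasses implement `_forward` and return raw tensors, and the shared `forward` above applies the softplus/logsigmoid activations.

# Hypothetical subclass showing what _forward is expected to return; the real
# base class lives in sparse.py (its name is not visible in this diff).
import torch
import torch.nn as nn

class ToySparseModel(nn.Module):
    def __init__(self, in_channels: int, state_dim: int):
        super().__init__()
        self.state_dim = state_dim
        self.head = nn.Linear(in_channels, 2 * state_dim + 1)

    def _forward(self, x, x_pos, x_batch):
        out = self.head(x)
        mu = out[:, : self.state_dim]                        # predicted states
        sigma = out[:, self.state_dim : 2 * self.state_dim]  # raw, no activation
        logits = out[:, -1]                                  # existence logits
        return mu, sigma, logits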
@@ -120,10 +141,8 @@ def logp(
         for batch_idx in range(batch_size):
             # find a matching between mu_i and y_j
             with torch.no_grad():
-                match_cost = (
-                    torch.cdist(mu_split[batch_idx], y_split[batch_idx], p=2)
-                    - logp_split[batch_idx][:, None]
-                )
+                dist = torch.cdist(mu_split[batch_idx], y_split[batch_idx], p=2)
+                match_cost = dist - logp_split[batch_idx][:, None]
             future = e.submit(linear_sum_assignment, match_cost.cpu().numpy())
             futures[future] = batch_idx
 
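The refactor only splits the cost expression into two lines; the matching itself is unchanged. A standalone Python sketch of that step (assumed shapes):

# Pairwise L2 distance between predictions and targets, discounted by each
# prediction's log existence probability, solved as a linear assignment problem.
import torch
from scipy.optimize import linear_sum_assignment

mu = torch.randn(5, 2)             # 5 predicted positions
y = torch.randn(3, 2)              # 3 ground-truth positions
logp = torch.rand(5).log()         # per-prediction log existence probability (<= 0)

dist = torch.cdist(mu, y, p=2)     # (5, 3) pairwise distances
match_cost = dist - logp[:, None]  # confident predictions are cheaper to match
rows, cols = linear_sum_assignment(match_cost.numpy())  # matched (pred, target) pairs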
11 changes: 5 additions & 6 deletions src/mtt/models/transformer.py
@@ -318,12 +318,12 @@ def __init__(
             n_channels, n_decoder, pos_dim, heads, dropout
         )
 
-    def forward(
+    def _forward(
         self,
         x: torch.Tensor,
         x_pos: torch.Tensor,
         x_batch: Optional[torch.Tensor] = None,
-    ) -> STOutput:
+    ):
         """
         Args:
             x: (N, in_channels) Non-position measurement data.
@@ -368,7 +368,6 @@ def forward(
         object = self.readout.forward(object, x_batch)
         # Split into existence probability and state
         mu = object[..., : self.state_dim]
-        # Sigma must be > 0.0 for torch.distributions.Normal
-        sigma = object[..., self.state_dim : 2 * self.state_dim].abs().clamp(min=1e-16)
-        logp = F.logsigmoid(object[..., -1])
-        return STOutput(mu, sigma, logp)
+        sigma = object[..., self.state_dim : 2 * self.state_dim]
+        logits = object[..., -1]
+        return mu, sigma, logits
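The transformer no longer applies its own positivity transform; that now happens once in the shared `forward` in sparse.py. A quick Python illustration (toy values) of the old per-model abs().clamp() versus the softplus used centrally:

# Both keep sigma strictly positive for torch.distributions.Normal; softplus is
# smooth and monotone in the raw value, while abs+clamp folds negatives back.
import torch
import torch.nn.functional as F

raw = torch.tensor([-3.0, 0.0, 2.0])
old_sigma = raw.abs().clamp(min=1e-16)  # previous transformer-specific transform
new_sigma = F.softplus(raw) + 1e-16     # transform now applied in the base forward
print(old_sigma)  # ~ 3.0, 1e-16, 2.0
print(new_sigma)  # ~ 0.0486, 0.6931, 2.1269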
