♻️
Alexandre ANDRE committed Mar 5, 2025 · 1 parent 370044b · commit 35f86d2
Showing 1 changed file with 31 additions and 32 deletions.
torch_brain/models/ndt2.py (63 changes: 31 additions & 32 deletions)
@@ -22,31 +22,30 @@ def __init__(
         mask_ratio: float,
         dim,
         ctx_keys: List[str],
-        patch_size: Tuple[int, int],
+        units_per_patch: int,
         max_bincount: int,
         spike_pad: int,
         max_time_patches: int,
         max_space_patches: int,
         bin_time: float,
-        depth,
-        heads,
-        dropout,
-        ffn_mult,
-        causal=True,
-        activation="gelu",
-        pre_norm=False,
+        depth: int,
+        heads: int,
+        dropout: float,
+        ffn_mult: float,
+        causal: bool = True,
+        activation: str = "gelu",
+        pre_norm: bool = False,
         predictor_cfg: Dict = None,
         bhv_decoder_cfg: Dict = None,
     ):
         super().__init__()
-        # TODO should be changed for 1 int (we should only patch neurons not time)
-        spike_embed_dim = round(dim / patch_size[0])
+        spike_embed_dim = round(dim / units_per_patch)
         self.bincount_emb = nn.Embedding(max_bincount, spike_embed_dim, padding_idx=pad)
         self.time_emb = nn.Embedding(max_time_patches, dim)
         self.space_emb = nn.Embedding(max_space_patches, dim)
         self.session_emb = InfiniteVocabEmbedding(dim)
         self.subject_emb = InfiniteVocabEmbedding(dim)
-        self.task_emb = InfiniteVocabEmbedding(dim)  # TODO more about dataset than task
+        self.task_emb = InfiniteVocabEmbedding(dim)  # more about dataset than task
 
         # Encoder
         enc_layer = nn.TransformerEncoderLayer(
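After this change a patch covers `units_per_patch` neurons within a single time bin, and each neuron's bin count is embedded into roughly `dim / units_per_patch` features, so a full patch flattens back to the model width. A minimal sketch of that arithmetic, with made-up sizes (`dim=256`, `units_per_patch=8`, `max_bincount=32`) and a hypothetical `pad` index reserved for padded entries; none of these values come from the repo:

```python
import torch
import torch.nn as nn

# Assumed sizes, not from the repo.
dim, units_per_patch = 256, 8
max_bincount, pad = 32, 31            # counts are clipped to pad - 1; pad marks padding

spike_embed_dim = round(dim / units_per_patch)   # 32 features per unit
bincount_emb = nn.Embedding(max_bincount, spike_embed_dim, padding_idx=pad)

# One token = the spike counts of 8 units in one time bin.
tokens = torch.randint(0, pad, (4, 10, units_per_patch))  # (batch, tokens, units)
feats = bincount_emb(tokens)                   # (4, 10, 8, 32)
flat = feats.flatten(start_dim=-2)             # (4, 10, 256): back to the model width
```

Note that because of the `round()`, the flattened width only equals `dim` exactly when `units_per_patch` divides `dim`.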
@@ -69,7 +68,7 @@ def __init__(
                 dim=dim,
                 max_time_patches=max_time_patches,
                 max_space_patches=max_space_patches,
-                patch_size=patch_size,
+                patch_size=units_per_patch,
                 **predictor_cfg,
             )
         else:
@@ -138,7 +137,7 @@ def __init__(
         self,
         bin_time: float,
         ctx_time: float,
-        patch_size: Tuple[int, int],
+        units_per_patch: int,
         pad_value: int,
         ctx_tokenizer: Dict[str, InfiniteVocabEmbedding],
         unsorted=True,
@@ -151,7 +150,7 @@ def __init__(
         self.bin_time: float = bin_time
         self.ctx_time: float = ctx_time
         self.bin_size: int = int(np.round(ctx_time / bin_time))
-        self.patch_size: Tuple[int, int] = patch_size  # (num_neurons, num_time_bins)
+        self.units_per_patch: int = units_per_patch
 
     def float_modulo_test(x, y, eps=1e-6):
         return np.abs(x - y * np.round(x / y)) < eps
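`bin_size` is the number of `bin_time`-wide bins in one context window, and `float_modulo_test` guards against a `ctx_time` that is not a whole multiple of `bin_time`. A quick check with illustrative values (20 ms bins, 1 s context; the numbers are not from the repo):

```python
import numpy as np

def float_modulo_test(x, y, eps=1e-6):
    # True when x is an integer multiple of y, up to float error.
    return np.abs(x - y * np.round(x / y)) < eps

bin_time, ctx_time = 0.02, 1.0                # illustrative values
assert float_modulo_test(ctx_time, bin_time)  # 1.0 s is 50 x 0.02 s
bin_size = int(np.round(ctx_time / bin_time))
print(bin_size)                               # 50 bins per context window
```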
@@ -183,29 +182,29 @@ def __call__(self, data: Data) -> Dict:
         binned_spikes = bin_spikes(data.spikes, num_units, self.bin_size)
         binned_spikes = np.clip(binned_spikes, 0, self.pad_value - 1)
 
-        num_spatial_patches = int(np.ceil(binned_spikes.shape[0] / self.patch_size[0]))
-        num_temporal_patches = int(np.ceil(binned_spikes.shape[1] / self.patch_size[1]))
-
-        extra_units = num_spatial_patches * self.patch_size[0] - binned_spikes.shape[0]
-        # TODO should not be needed as we dont patch time
-        extra_time = num_temporal_patches * self.patch_size[1] - binned_spikes.shape[1]
+        nb_units = binned_spikes.shape[0]
+        num_spatial_patches = int(np.ceil(nb_units / self.units_per_patch))
+        extra_units = num_spatial_patches * self.units_per_patch - nb_units
 
-        if extra_units > 0 or extra_time > 0:
+        if extra_units > 0:
             binned_spikes = np.pad(
                 binned_spikes,
-                [(0, extra_units), (0, extra_time)],
+                [(0, extra_units), (0, 0)],
                 mode="constant",
                 constant_values=self.pad_value,
             )
 
+        num_temporal_patches = binned_spikes.shape[1]
+
         # major hack to have time before space, as in o.g. NDT2 (nb_units, time_length)
         # TODO could be much cleaner
         binned_spikes = rearrange(
             binned_spikes,
             "(n pn) (t pt) -> (t n) pn pt",
             n=num_spatial_patches,
             t=num_temporal_patches,
-            pn=self.patch_size[0],
-            pt=self.patch_size[1],
+            pn=self.units_per_patch,
+            pt=1,
         )
 
         # time and space indices for flattened patches
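The rearrange above is the core of the change: with `pt=1` only the unit axis is patched, and the `(t n)` output axis lays tokens out time-major, matching the original NDT2 ordering. A toy run with assumed sizes (10 units, patches of 4, 5 time bins; values are not from the repo), padding the time axis explicitly with `(0, 0)`:

```python
import numpy as np
from einops import rearrange

units_per_patch, nb_units, num_bins = 4, 10, 5   # toy sizes
pad_value = 31

num_spatial_patches = int(np.ceil(nb_units / units_per_patch))  # 3
extra_units = num_spatial_patches * units_per_patch - nb_units  # 2 fake units

binned = np.random.randint(0, 5, size=(nb_units, num_bins))
binned = np.pad(binned, [(0, extra_units), (0, 0)], constant_values=pad_value)

# pt=1 keeps time unpatched; the (t n) axis is time-major:
# token order is (t0, n0), (t0, n1), (t0, n2), (t1, n0), ...
tokens = rearrange(
    binned,
    "(n pn) (t pt) -> (t n) pn pt",
    n=num_spatial_patches, t=num_bins, pn=units_per_patch, pt=1,
)
print(tokens.shape)  # (15, 4, 1): (num_bins * num_spatial_patches, units_per_patch, 1)
```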
@@ -246,14 +245,14 @@ def __call__(self, data: Data) -> Dict:
         )
 
         shape = (num_temporal_patches, num_spatial_patches)
-        channel_counts = torch.full(shape, self.patch_size[0], dtype=torch.long)
+        units_count = torch.full(shape, self.units_per_patch, dtype=torch.long)
 
         # last patch may have fewer units
         if num_units % num_spatial_patches != 0:
-            channel_counts[:, -1] = self.patch_size[0] - extra_units
-        channel_counts = rearrange(
-            channel_counts,
-            "t n -> (t n)",
-            n=num_spatial_patches,
-            t=num_temporal_patches,
-        )
+            units_count[:, -1] = self.units_per_patch - extra_units
+
+        units_count = rearrange(
+            units_count, "t n -> (t n)", n=num_spatial_patches, t=num_temporal_patches
+        )
 
         session_idx = self.session_tokenizer(data.session.id)
@@ -265,7 +264,7 @@
             "spike_tokens_mask": track_mask(spikes),
             "time_idx": pad(time_idx),
             "space_idx": pad(space_idx),
-            "channel_counts": pad(channel_counts),
+            "units_count": pad(units_count),
             "session_idx": session_idx,
             "subject_idx": subject_idx,
             "task_index": task_idx,
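`units_count` records how many real units each token carries, so the padded fake units in a ragged last patch can be masked out downstream. A sketch with the same toy sizes as above (10 units in patches of 4, 5 time bins; values assumed, not from the repo):

```python
import torch
from einops import rearrange

units_per_patch, num_units = 4, 10       # toy sizes, as above
num_spatial_patches, extra_units = 3, 2
num_temporal_patches = 5

shape = (num_temporal_patches, num_spatial_patches)
units_count = torch.full(shape, units_per_patch, dtype=torch.long)

# The last spatial patch holds only 4 - 2 = 2 real units here.
if num_units % num_spatial_patches != 0:
    units_count[:, -1] = units_per_patch - extra_units

units_count = rearrange(
    units_count, "t n -> (t n)", n=num_spatial_patches, t=num_temporal_patches
)
print(units_count.tolist())  # [4, 4, 2, 4, 4, 2, ...]: one count per token
```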
