From 35f86d266f860f80867203e1bdceeef74006c6e7 Mon Sep 17 00:00:00 2001
From: Alexandre ANDRE
Date: Thu, 13 Feb 2025 22:50:44 -0500
Subject: [PATCH] :recycle: patch units only: replace patch_size tuple with units_per_patch

---
 torch_brain/models/ndt2.py | 63 +++++++++++++++++++--------------------
 1 file changed, 31 insertions(+), 32 deletions(-)

diff --git a/torch_brain/models/ndt2.py b/torch_brain/models/ndt2.py
index ac8001e6..f67204e9 100644
--- a/torch_brain/models/ndt2.py
+++ b/torch_brain/models/ndt2.py
@@ -22,31 +22,30 @@ def __init__(
         mask_ratio: float,
         dim,
         ctx_keys: List[str],
-        patch_size: Tuple[int, int],
+        units_per_patch: int,
         max_bincount: int,
         spike_pad: int,
         max_time_patches: int,
         max_space_patches: int,
         bin_time: float,
-        depth,
-        heads,
-        dropout,
-        ffn_mult,
-        causal=True,
-        activation="gelu",
-        pre_norm=False,
+        depth: int,
+        heads: int,
+        dropout: float,
+        ffn_mult: float,
+        causal: bool = True,
+        activation: str = "gelu",
+        pre_norm: bool = False,
         predictor_cfg: Dict = None,
         bhv_decoder_cfg: Dict = None,
     ):
         super().__init__()
-        # TODO should be changed for 1 int (we should only patch neurons not time)
-        spike_embed_dim = round(dim / patch_size[0])
+        spike_embed_dim = round(dim / units_per_patch)
         self.bincount_emb = nn.Embedding(max_bincount, spike_embed_dim, padding_idx=pad)
         self.time_emb = nn.Embedding(max_time_patches, dim)
         self.space_emb = nn.Embedding(max_space_patches, dim)
         self.session_emb = InfiniteVocabEmbedding(dim)
         self.subject_emb = InfiniteVocabEmbedding(dim)
-        self.task_emb = InfiniteVocabEmbedding(dim)  # TODO more about dataset than task
+        self.task_emb = InfiniteVocabEmbedding(dim)  # more about dataset than task
 
         # Encoder
         enc_layer = nn.TransformerEncoderLayer(
@@ -69,7 +68,7 @@ def __init__(
                 dim=dim,
                 max_time_patches=max_time_patches,
                 max_space_patches=max_space_patches,
-                patch_size=patch_size,
+                patch_size=units_per_patch,
                 **predictor_cfg,
             )
         else:
@@ -138,7 +137,7 @@ def __init__(
         self,
         bin_time: float,
         ctx_time: float,
-        patch_size: Tuple[int, int],
+        units_per_patch: int,
         pad_value: int,
         ctx_tokenizer: Dict[str, InfiniteVocabEmbedding],
         unsorted=True,
@@ -151,7 +150,7 @@ def __init__(
         self.bin_time: float = bin_time
         self.ctx_time: float = ctx_time
         self.bin_size: int = int(np.round(ctx_time / bin_time))
-        self.patch_size: Tuple[int, int] = patch_size  # (num_neurons, num_time_bins)
+        self.units_per_patch: int = units_per_patch
 
         def float_modulo_test(x, y, eps=1e-6):
             return np.abs(x - y * np.round(x / y)) < eps
@@ -183,29 +182,29 @@ def __call__(self, data: Data) -> Dict:
         binned_spikes = bin_spikes(data.spikes, num_units, self.bin_size)
         binned_spikes = np.clip(binned_spikes, 0, self.pad_value - 1)
 
-        num_spatial_patches = int(np.ceil(binned_spikes.shape[0] / self.patch_size[0]))
-        num_temporal_patches = int(np.ceil(binned_spikes.shape[1] / self.patch_size[1]))
+        nb_units = binned_spikes.shape[0]
+        num_spatial_patches = int(np.ceil(nb_units / self.units_per_patch))
+        extra_units = num_spatial_patches * self.units_per_patch - nb_units
 
-        extra_units = num_spatial_patches * self.patch_size[0] - binned_spikes.shape[0]
-        # TODO should not be needed as we dont patch time
-        extra_time = num_temporal_patches * self.patch_size[1] - binned_spikes.shape[1]
-
-        if extra_units > 0 or extra_time > 0:
+        if extra_units > 0:
             binned_spikes = np.pad(
                 binned_spikes,
-                [(0, extra_units), (0, extra_time)],
+                [(0, extra_units), (0, 0)],  # pad the unit axis only; time is not patched
                 mode="constant",
                 constant_values=self.pad_value,
             )
 
+        num_temporal_patches = binned_spikes.shape[1]
+        # hack to keep time before space, as in the original NDT2 (nb_units, time_length)
+        # TODO: could be much cleaner
         binned_spikes = rearrange(
             binned_spikes,
             "(n pn) (t pt) -> (t n) pn pt",
             n=num_spatial_patches,
             t=num_temporal_patches,
-            pn=self.patch_size[0],
-            pt=self.patch_size[1],
+            pn=self.units_per_patch,
+            pt=1,
         )
 
         # time and space indices for flattened patches
@@ -246,14 +245,14 @@ def __call__(self, data: Data) -> Dict:
         )
 
         shape = (num_temporal_patches, num_spatial_patches)
-        channel_counts = torch.full(shape, self.patch_size[0], dtype=torch.long)
+        units_count = torch.full(shape, self.units_per_patch, dtype=torch.long)
+
+        # last patch may have fewer units
         if num_units % num_spatial_patches != 0:
-            channel_counts[:, -1] = self.patch_size[0] - extra_units
-        channel_counts = rearrange(
-            channel_counts,
-            "t n -> (t n)",
-            n=num_spatial_patches,
-            t=num_temporal_patches,
+            units_count[:, -1] = self.units_per_patch - extra_units
+
+        units_count = rearrange(
+            units_count, "t n -> (t n)", n=num_spatial_patches, t=num_temporal_patches
        )
 
         session_idx = self.session_tokenizer(data.session.id)
@@ -265,7 +264,7 @@ def __call__(self, data: Data) -> Dict:
             "spike_tokens_mask": track_mask(spikes),
             "time_idx": pad(time_idx),
             "space_idx": pad(space_idx),
-            "channel_counts": pad(channel_counts),
+            "units_count": pad(units_count),
             "session_idx": session_idx,
             "subject_idx": subject_idx,
             "task_index": task_idx,
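
For reference, a minimal sketch of the units-only patching this patch introduces, using numpy and einops as ndt2.py does; the toy sizes (5 units, 4 time bins, units_per_patch=2) and pad_value=0 are illustrative assumptions, not values from the repository:

import numpy as np
from einops import rearrange

units_per_patch = 2  # illustrative value
pad_value = 0        # illustrative value

# (num_units, num_time_bins) spike counts; 5 units do not divide evenly into patches of 2
binned_spikes = np.arange(5 * 4).reshape(5, 4)

nb_units = binned_spikes.shape[0]
num_spatial_patches = int(np.ceil(nb_units / units_per_patch))   # 3
extra_units = num_spatial_patches * units_per_patch - nb_units   # 1

if extra_units > 0:
    # pad only the unit axis; the time axis is no longer patched
    binned_spikes = np.pad(
        binned_spikes,
        [(0, extra_units), (0, 0)],
        mode="constant",
        constant_values=pad_value,
    )

num_temporal_patches = binned_spikes.shape[1]  # one patch per time bin (pt=1)

tokens = rearrange(
    binned_spikes,
    "(n pn) (t pt) -> (t n) pn pt",
    n=num_spatial_patches,
    t=num_temporal_patches,
    pn=units_per_patch,
    pt=1,
)
print(tokens.shape)  # (12, 2, 1): each token holds units_per_patch units for one time bin

Each of the resulting tokens covers units_per_patch units for a single time bin, and tokens are ordered time-major, matching the (t n) output axis of the rearrange pattern in the patch.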