fix decoding module docstrings & add static typing
kushaangupta committed Dec 27, 2024
1 parent d8e969c commit 8c55cec
Showing 7 changed files with 1,236 additions and 397 deletions.
238 changes: 175 additions & 63 deletions neuro_py/ensemble/decoding/lstm.py
@@ -1,3 +1,5 @@
from typing import List, Dict, Tuple, Optional

import torch
import torch.nn.functional as F
import lightning as L
@@ -6,43 +8,60 @@


class LSTM(L.LightningModule):
"""Long Short-Term Memory (LSTM) model."""
def __init__(self, in_dim=100, out_dim=2, hidden_dims=(400, 1, .0), use_bias=True, args={}):
"""
Constructs a LSTM model
Parameters
----------
in_dim : int
Dimensionality of input data
out_dim : int
Dimensionality of output data
hidden_dims : List
Architectural parameters of the model
(hidden_size, num_layers, dropout)
use_bias : bool
Whether to use bias or not in the final linear layer
"""
"""
Long Short-Term Memory (LSTM) model.
This class implements an LSTM model using PyTorch Lightning.
Parameters
----------
in_dim : int, optional
Dimensionality of input data, by default 100
out_dim : int, optional
Dimensionality of output data, by default 2
hidden_dims : Tuple[int, int, float], optional
Architectural parameters of the model (hidden_size, num_layers, dropout),
by default (400, 1, 0.0)
use_bias : bool, optional
Whether to use bias or not in the final linear layer, by default True
args : Dict, optional
Additional arguments for model configuration, by default {}
Attributes
----------
lstm : nn.LSTM
LSTM layer
fc : nn.Linear
Fully connected layer
hidden_state : Optional[torch.Tensor]
Hidden state of the LSTM
cell_state : Optional[torch.Tensor]
Cell state of the LSTM
"""
def __init__(self, in_dim: int = 100, out_dim: int = 2,
hidden_dims: Tuple[int, int, float] = (400, 1, 0.0),
use_bias: bool = True, args: Dict = {}):
super().__init__()
self.save_hyperparameters()
self.in_dim = in_dim
self.out_dim = out_dim
if len(hidden_dims) != 3:
raise ValueError('`hidden_dims` should be of size 3')
hidden_size, nlayers, dropout = hidden_dims
self.nlayers = nlayers
self.hidden_size = hidden_size
self.dropout = dropout
self.hidden_size, self.nlayers, self.dropout = hidden_dims
self.args = args

# Add final layer to the number of classes
self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hidden_size,
num_layers=nlayers, batch_first=True, dropout=dropout, bidirectional=True)
self.fc = nn.Linear(in_features=2*hidden_size, out_features=out_dim, bias=use_bias)
self.hidden_state = None
self.cell_state = None
self.lstm = nn.LSTM(input_size=in_dim, hidden_size=self.hidden_size,
num_layers=self.nlayers, batch_first=True,
dropout=self.dropout, bidirectional=True)
self.fc = nn.Linear(in_features=2*self.hidden_size, out_features=out_dim, bias=use_bias)
self.hidden_state: Optional[torch.Tensor] = None
self.cell_state: Optional[torch.Tensor] = None

def init_params(m):
self._init_params()
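
For orientation, a minimal constructor sketch (illustrative values only; the `args` keys are the ones this class reads elsewhere in the file: 'clf', 'criterion', 'weight_decay', 'lr', 'epochs'; the import path is assumed from the file location shown above):

import torch.nn as nn
from neuro_py.ensemble.decoding.lstm import LSTM

model = LSTM(
    in_dim=100,                  # input features per time step
    out_dim=2,                   # e.g. decoded 2D position
    hidden_dims=(400, 1, 0.0),   # (hidden_size, num_layers, dropout)
    use_bias=True,
    args={
        'clf': False,            # regression: skip the log-softmax
        'criterion': nn.MSELoss(),
        'weight_decay': 1e-2,    # illustrative hyperparameters
        'lr': 1e-3,
        'epochs': 50,
    },
)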

def _init_params(self) -> None:
"""Initialize model parameters."""
def init_params(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
if m.bias is not None:
@@ -51,35 +70,63 @@ def init_params(m):
nn.init.uniform_(m.bias, -bound, bound) # LeCunn init
init_params(self.fc)

def forward(self, x):
lstm_out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
B, L, H = lstm_out.shape
# Shape: [batch_size x max_length x hidden_dim]
# Select the activation of the last Hidden Layer
# lstm_out = lstm_out.view(B, L, 2, -1).sum(dim=2)
lstm_out = lstm_out[:,-1,:].contiguous()
# Shape: [batch_size x hidden_dim]
# Fully connected layer

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the LSTM model.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, sequence_length, input_dim)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, output_dim)
"""
lstm_out, (self.hidden_state, self.cell_state) = \
self.lstm(x, (self.hidden_state, self.cell_state))
lstm_out = lstm_out[:, -1, :].contiguous()
out = self.fc(lstm_out)
if self.args['clf']:
if self.args.get('clf', False):
out = F.log_softmax(out, dim=1)

return out
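
Continuing the sketch above, a hedged example of a forward pass: `forward` reads `self.hidden_state` and `self.cell_state`, so `init_hidden` must be called first with a matching batch size; the bidirectional output at the last time step is projected to `out_dim`.

import torch

batch_size, seq_len = 32, 50
model.init_hidden(batch_size)              # states: (2*num_layers, batch, hidden_size)
x = torch.randn(batch_size, seq_len, 100)  # (batch, time, in_dim), batch_first=True
y_hat = model(x)                           # (batch, out_dim), here (32, 2)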

def init_hidden(self, batch_size):
''' Initializes hidden state '''
# Create two new tensors with sizes n_layers x batch_size x hidden_dim,
# initialized to zero, for hidden state and cell state of LSTM
def init_hidden(self, batch_size: int) -> None:
"""
Initialize hidden state and cell state.
Parameters
----------
batch_size : int
Batch size for initialization
"""
self.batch_size = batch_size
h0 = torch.zeros((2*self.nlayers,batch_size,self.hidden_size), requires_grad=False)
c0 = torch.zeros((2*self.nlayers,batch_size,self.hidden_size), requires_grad=False)
h0 = torch.zeros(
(2*self.nlayers, batch_size, self.hidden_size),
requires_grad=False
)
c0 = torch.zeros(
(2*self.nlayers, batch_size, self.hidden_size),
requires_grad=False
)
self.hidden_state = h0
self.cell_state = c0
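
To make the state shapes concrete (an illustrative check): the leading dimension is 2 * num_layers because the LSTM is bidirectional.

model.init_hidden(batch_size=16)
assert model.hidden_state.shape == (2 * model.nlayers, 16, model.hidden_size)
assert model.cell_state.shape == (2 * model.nlayers, 16, model.hidden_size)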

def predict(self, x):
def predict(self, x: torch.Tensor) -> torch.Tensor:
"""
Make predictions using the LSTM model.
Parameters
----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Predicted output
"""
self.hidden_state = self.hidden_state.to(x.device)
self.cell_state = self.cell_state.to(x.device)
preds = []
@@ -93,45 +140,110 @@ def predict(self, x):
pred_loc = pred_loc[:batch_size-(i-x.shape[0])]
preds.extend(pred_loc)
out = torch.stack(preds)
if self.args['clf']:
if self.args.get('clf', False):
out = F.log_softmax(out, dim=1)
return out
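
A hedged usage sketch for `predict` (its loop body is elided above; it appears to process the input in chunks sized by the batch used for `init_hidden`, so the states must be initialized first):

model.eval()
with torch.no_grad():
    model.init_hidden(batch_size=32)
    X_test = torch.randn(1000, 50, 100)  # (n_samples, time, in_dim)
    preds = model.predict(X_test)        # stacked predictions, one row per sample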

def _step(self, batch, batch_idx) -> torch.Tensor:
xs, ys = batch # unpack the batch
outs = self(xs) # apply the model
def _step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""
Perform a single step (forward pass + loss calculation).
Parameters
----------
batch : Tuple[torch.Tensor, torch.Tensor]
Batch of input data and labels
batch_idx : int
Index of the current batch
Returns
-------
torch.Tensor
Computed loss
"""
xs, ys = batch
outs = self(xs)
loss = self.args['criterion'](outs, ys)
return loss
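
`_step` delegates the loss to `args['criterion']`. A natural pairing for the two modes this file supports, offered as an assumption rather than the repository's choice:

# classification: forward() returns log-probabilities when args['clf'] is truthy
clf_args = {'clf': True, 'criterion': nn.NLLLoss()}
# regression: raw linear outputs
reg_args = {'clf': False, 'criterion': nn.MSELoss()}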

def training_step(self, batch, batch_idx) -> torch.Tensor:
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""
Lightning method for training step.
Parameters
----------
batch : Tuple[torch.Tensor, torch.Tensor]
Batch of input data and labels
batch_idx : int
Index of the current batch
Returns
-------
torch.Tensor
Computed loss
"""
loss = self._step(batch, batch_idx)
self.log('train_loss', loss)
return loss

def on_after_backward(self):
# LSTM specific
def on_after_backward(self) -> None:
"""Lightning method called after backpropagation."""
self.hidden_state.detach_()
self.cell_state.detach_()
# self.hidden_state.data.fill_(.0)
# self.cell_state.data.fill_(.0)

def validation_step(self, batch, batch_idx) -> torch.Tensor:
def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""
Lightning method for validation step.
Parameters
----------
batch : Tuple[torch.Tensor, torch.Tensor]
Batch of input data and labels
batch_idx : int
Index of the current batch
Returns
-------
torch.Tensor
Computed loss
"""
loss = self._step(batch, batch_idx)
self.log('val_loss', loss)
return loss

def test_step(self, batch, batch_idx) -> torch.Tensor:
def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""
Lightning method for test step.
Parameters
----------
batch : Tuple[torch.Tensor, torch.Tensor]
Batch of input data and labels
batch_idx : int
Index of the current batch
Returns
-------
torch.Tensor
Computed loss
"""
loss = self._step(batch, batch_idx)
self.log('test_loss', loss)
return loss
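
As a LightningModule, the model would typically be driven by a Lightning Trainer; a minimal sketch with assumed dataloaders (`train_loader`, `val_loader`, `test_loader` are placeholders):

import lightning as L

trainer = L.Trainer(max_epochs=model.args['epochs'])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=test_loader)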

def configure_optimizers(self):
args = self.args
def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Dict]]:
"""
Configure optimizers and learning rate schedulers.
Returns
-------
Tuple[List[torch.optim.Optimizer], List[Dict]]
Tuple containing a list of optimizers and a list of scheduler configurations
"""
optimizer = torch.optim.AdamW(
self.parameters(), weight_decay=args['weight_decay'])
self.parameters(), weight_decay=self.args['weight_decay'])
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=args['lr'],
epochs=args['epochs'],
optimizer, max_lr=self.args['lr'],
epochs=self.args['epochs'],
steps_per_epoch=len(
self.trainer._data_connector._train_dataloader_source.dataloader()
)