Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Dec 14, 2023
1 parent 9182247 commit c4cb0ab
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
11 changes: 5 additions & 6 deletions sleap_nn/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def train(self):
# logger to create csv with metrics values over the epochs
csv_logger = CSVLogger(dir_path)
self.logger.append(csv_logger)

# save the configs as yaml in the checkpoint dir
OmegaConf.save(config=self.config, f=dir_path + "config.yaml")

else:
Expand All @@ -143,9 +145,9 @@ def train(self):
)

trainer.fit(self.model, self.train_data_loader, self.val_data_loader)
# save the configs as yaml in the checkpoint dir

wandb.finish()
if self.config.trainer_config.use_wandb:
wandb.finish()


def xavier_init_weights(x):
Expand All @@ -164,7 +166,7 @@ class TopDownCenteredInstanceModel(L.LightningModule):
config: OmegaConf dictionary which has the following:
(i) data_config: data loading pre-processing configs to be passed to `TopdownConfmapsPipeline` class.
(ii) model_config: backbone and head configs to be passed to `Model` class.
(iii) trainer_cong: trainer configs like accelerator, optimiser params.
(iii) trainer_config: trainer configs like accelerator, optimiser params.
"""

Expand Down Expand Up @@ -220,7 +222,6 @@ def training_step(self, batch, batch_idx):
self.log(
"train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=True
)
self.training_loss[self.current_epoch] = loss.detach().cpu().numpy()
return loss

def validation_step(self, batch, batch_idx):
Expand Down Expand Up @@ -248,8 +249,6 @@ def validation_step(self, batch, batch_idx):
on_epoch=True,
logger=True,
)
self.val_loss[self.current_epoch] = val_loss.detach().cpu().numpy()
self.learning_rate[self.current_epoch] = lr

def configure_optimizers(self):
"""Configure optimiser and learning rate scheduler."""
Expand Down
1 change: 0 additions & 1 deletion tests/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def test_create_data_loader(config, sleap_data_dir, tmp_path: str):
)

OmegaConf.update(config, "data_config.test", config_test)
print(config.data_config.test)
model_trainer = ModelTrainer(config)
model_trainer._create_data_loaders()
assert isinstance(
Expand Down

0 comments on commit c4cb0ab

Please sign in to comment.