From fba34d9f0a9bee588ad1ed4b1489d1dd5893909b Mon Sep 17 00:00:00 2001 From: Jon Deaton Date: Fri, 1 Dec 2023 15:57:12 -0800 Subject: [PATCH] Updated to run hg38 experiment on newer pytorch lightning version. --- requirements.txt | 4 ++-- src/dataloaders/datasets/hg38_char_tokenizer.py | 3 +++ train.py | 8 -------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0d1b04f..32f1bc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ scikit-learn matplotlib tqdm rich -pytorch-lightning==1.8.6 +pytorch-lightning==2.1.2 hydra-core omegaconf wandb @@ -29,4 +29,4 @@ pyfaidx polars genomic-benchmarks loguru -liftover \ No newline at end of file +liftover diff --git a/src/dataloaders/datasets/hg38_char_tokenizer.py b/src/dataloaders/datasets/hg38_char_tokenizer.py index b60408e..e62dad2 100644 --- a/src/dataloaders/datasets/hg38_char_tokenizer.py +++ b/src/dataloaders/datasets/hg38_char_tokenizer.py @@ -71,6 +71,9 @@ def __init__(self, characters: Sequence[str], model_max_length: int, padding_sid def vocab_size(self) -> int: return len(self._vocab_str_to_int) + def get_vocab(self) -> dict[str, int]: + return self.get_added_vocab() + def _tokenize(self, text: str) -> List[str]: return list(text) diff --git a/train.py b/train.py index a052a99..c8699c8 100644 --- a/train.py +++ b/train.py @@ -364,19 +364,11 @@ def on_train_epoch_start(self): # Reset training torchmetrics self.task._reset_torchmetrics("train") - def training_epoch_end(self, outputs): - # Log training torchmetrics - super().training_epoch_end(outputs) - def on_validation_epoch_start(self): # Reset all validation torchmetrics for name in self.val_loader_names: self.task._reset_torchmetrics(name) - def validation_epoch_end(self, outputs): - # Log all validation torchmetrics - super().validation_epoch_end(outputs) - def on_test_epoch_start(self): # Reset all test torchmetrics for name in self.test_loader_names: