feat: pretrain with mlflow #563

Merged · 26 commits · Dec 18, 2023

6 changes: 4 additions & 2 deletions psycop/common/sequence_models/config_schema.py
@@ -2,11 +2,12 @@
The config Schema for sequence models.
"""

from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union

from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.loggers import Logger as plLogger
from pydantic import BaseModel

from psycop.common.sequence_models.tasks import (
@@ -32,7 +33,7 @@ class Config:
num_nodes: int = 1
callbacks: list[Callback] = []
precision: str = "32-true"
logger: Optional[WandbLogger] = None
logger: Optional[plLogger] = None
max_epochs: Optional[int] = None
min_epochs: Optional[int] = None
max_steps: int = 10
@@ -100,3 +101,4 @@ class Config:
# Required because dataset and model are coupled through their input and outputs
model_and_dataset: PretrainingModelAndDataset | ClassificationModelAndDataset
training: TrainingConfigSchema
logger: Sequence[plLogger] | None = None
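
The schema now types the logger fields with Lightning's generic `Logger` base class instead of `WandbLogger`, and adds a top-level `logger` sequence, so any combination of Lightning loggers can be configured. A minimal sketch of what the broadened types permit (not part of this PR; logger values are illustrative):

```python
# Illustrative only: with the generic Logger type, the trainer no longer cares
# which backend is used, and a list of loggers also works.
import lightning.pytorch as pl
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from lightning.pytorch.loggers.wandb import WandbLogger

trainer = pl.Trainer(
    max_steps=10,
    logger=[
        MLFlowLogger(experiment_name="demo", save_dir="logs/"),  # hypothetical values
        WandbLogger(project="demo", save_dir="logs/", offline=True),
    ],
)
```
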
50 changes: 29 additions & 21 deletions psycop/common/sequence_models/logger.py
@@ -2,7 +2,9 @@
from pathlib import Path
from typing import Optional

from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.loggers import Logger as plLogger
from lightning.pytorch.loggers.mlflow import MLFlowLogger as plMLFlowLogger
from lightning.pytorch.loggers.wandb import WandbLogger as plWandbLogger

from .registry import Registry

@@ -20,28 +22,34 @@ def handle_wandb_folder():

@Registry.loggers.register("wandb")
def create_wandb_logger(
name: Optional[str] = None,
save_dir: Path | str = ".",
version: Optional[str] = None,
offline: bool = False,
dir: Optional[Path] = None, # noqa: A002
id: Optional[str] = None, # noqa: A002
anonymous: Optional[bool] = None,
project: Optional[str] = None,
prefix: str = "",
checkpoint_name: Optional[str] = None,
) -> WandbLogger:
save_dir: Path | str,
experiment_name: str,
offline: bool,
run_name: Optional[str] = None,
) -> plLogger:
handle_wandb_folder()

return WandbLogger(
name=name,
return plWandbLogger(
name=run_name,
save_dir=save_dir,
version=version,
offline=offline,
dir=dir,
id=id,
anonymous=anonymous,
project=project,
prefix=prefix,
checkpoint_name=checkpoint_name,
project=experiment_name,
)


@Registry.loggers.register("mlflow")
def create_mlflow_logger(
save_dir: Path | str,
experiment_name: str,
offline: bool = False,
run_name: Optional[str] = None,
) -> plLogger:
if offline:
raise NotImplementedError("MLFlow does not support offline mode")

return plMLFlowLogger(
save_dir=str(save_dir),
experiment_name=experiment_name,
run_name=run_name,
tracking_uri="http://exrhel0371.it.rm.dk:5050",
)
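
The new `mlflow` factory rejects offline mode and pins the tracking URI to the internal MLflow server. A minimal usage sketch, assuming the factory is called directly rather than resolved from a .cfg (values are illustrative):

```python
# Minimal sketch, assuming direct use of the registered factory
# (normally it is resolved from the config via @loggers = "mlflow").
import lightning.pytorch as pl

from psycop.common.sequence_models.logger import create_mlflow_logger

mlflow_logger = create_mlflow_logger(
    save_dir="logs/",
    experiment_name="pretrain_behrt",
    run_name="pretrain-2023-12-15",
)
# Metrics and hyperparameters are sent to the hardcoded tracking URI.
trainer = pl.Trainer(max_steps=10, logger=mlflow_logger)
```
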
12 changes: 12 additions & 0 deletions psycop/common/sequence_models/registry.py
@@ -1,3 +1,5 @@
from typing import TypeVar

import catalogue
from confection import registry

@@ -13,3 +15,13 @@ class Registry(registry):
optimizers = catalogue.create("psycop", "optimizers")
lr_schedulers = catalogue.create("psycop", "lr_schedulers")
callbacks = catalogue.create("psycop", "callbacks")

utilities = catalogue.create("psycop", "utilities")


T = TypeVar("T")


@Registry.utilities.register("list_creator")
def list_creator(*args: T) -> list[T]:
return list(args)
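
The new `utilities` registry and `list_creator` exist so a config block can collect several resolved loggers into a single list. A sketch of how a `[logger]` block resolves, assuming `Registry.resolve` behaves like confection's `registry.resolve` and that the logger factories have been imported so they are registered:

```python
# Sketch of config resolution under the assumptions above; values are illustrative.
from confection import Config

from psycop.common.sequence_models.logger import create_wandb_logger  # noqa: F401 (registers "wandb")
from psycop.common.sequence_models.registry import Registry

cfg_str = """
[logger]
@utilities = "list_creator"

[logger.*.wandb]
@loggers = "wandb"
save_dir = "logs/"
experiment_name = "test_experiment"
run_name = "test_run"
offline = true
"""

resolved = Registry.resolve(Config().from_str(cfg_str))
loggers = resolved["logger"]  # list_creator(*loggers) -> [WandbLogger(...)]
```
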
19 changes: 7 additions & 12 deletions psycop/common/sequence_models/tests/test_config.cfg
@@ -1,6 +1,3 @@



[training]
batch_size=64
num_workers_for_dataloader=2
@@ -42,24 +39,22 @@ monitor = "val_loss"
save_top_k = 5
every_n_epochs = 1
mode = "min"
save_dir = ${training.trainer.logger.save_dir}
save_dir = ${logger.*.wandb.save_dir}

[training.trainer.callbacks.*.learning_rate_monitor]
@callbacks = "learning_rate_monitor"
logging_interval = "epoch"


[training.trainer.logger]
[logger]
@utilities = "list_creator"

[logger.*.wandb]
@loggers = "wandb"
name = null
save_dir = "logs/"
version = null
experiment_name = "test_experiment"
run_name = "test_run"
offline = true
dir = null
id = null
anonymous = null
project = "psycop"
checkpoint_name = null

[model_and_dataset]
[model_and_dataset.model]
28 changes: 14 additions & 14 deletions psycop/common/sequence_models/train.py
@@ -13,7 +13,7 @@

from .config_utils import load_config, parse_config

std_logger = logging.getLogger(__name__)
log = logging.getLogger(__name__)
os.environ["WANDB__SERVICE_WAIT"] = "300" # to avoid issues with wandb service


@@ -26,7 +26,7 @@ def populate_registry() -> None:
"""
from .callbacks import create_learning_rate_monitor, create_model_checkpoint # noqa
from .embedders.BEHRT_embedders import create_behrt_embedder # noqa
from .logger import create_wandb_logger # noqa
from .logger import create_mlflow_logger, create_wandb_logger # noqa
from .model_layers import create_encoder_layer, create_transformers_encoder # noqa
from .optimizers import create_adam # noqa
from .optimizers import create_adamw # noqa
@@ -47,27 +47,27 @@ def train(config_path: Path | None = None) -> None:
config_dict = load_config(config_path)
config = parse_config(config_dict)

# Setup the logger and pass it to the TrainingConfig
training_cfg = config.training
if config.logger is not None:
for logger in config.logger:
# update config
log.info("Updating Config")
flat_config = flatten_nested_dict(config_dict)
logger.log_hyperparams(flat_config)

training_dataset = config.model_and_dataset.training_dataset
validation_dataset = config.model_and_dataset.validation_dataset
model = config.model_and_dataset.model
logger = training_cfg.trainer.logger
trainer_kwargs = training_cfg.trainer.to_dict()

# update config
std_logger.info("Updating Config")
flat_config = flatten_nested_dict(config_dict)

if logger:
logger.experiment.config.update(flat_config)

# filter dataset
std_logger.info("Filtering Patients")
log.info("Filtering Patients")
filter_fn = model.filter_and_reformat
training_dataset.filter_patients(filter_fn)
validation_dataset.filter_patients(filter_fn)

std_logger.info("Creating dataloaders")
log.info("Creating dataloaders")
train_loader = DataLoader(
training_dataset,
batch_size=training_cfg.batch_size,
@@ -85,9 +85,9 @@
persistent_workers=True,
)

std_logger.info("Initalizing trainer")
log.info("Initalizing trainer")
trainer = pl.Trainer(**trainer_kwargs)

std_logger.info("Starting training")
log.info("Starting training")
torch.set_float32_matmul_precision("medium")
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
22 changes: 9 additions & 13 deletions psycop/projects/sequence_models/pretrain_behrt.cfg
@@ -28,6 +28,7 @@ accumulate_grad_batches = 1
gradient_clip_val = null
gradient_clip_algorithm = null
default_root_dir = "logs/"
logger=${logger.mlflow}

[training.trainer.callbacks]
@callbacks = "callback_list"
@@ -38,27 +39,22 @@ monitor = "val_loss"
save_top_k = 2
every_n_epochs = 1
mode = "min"
save_dir = ${training.trainer.logger.save_dir}
save_dir = ${logger.mlflow.save_dir}

[training.trainer.callbacks.*.learning_rate_monitor]
@callbacks = "learning_rate_monitor"
logging_interval = "epoch"

[logger]
@utilities = "list_creator"

[training.trainer.logger]
@loggers = "wandb"
name = null
save_dir = "logs/"
version = null
offline = true
dir = null
id = null
anonymous = null
project = "psycop-sequence-models"
checkpoint_name = null
[logger.*.mlflow]
@loggers = "mlflow"
experiment_name = "pretrain_behrt"
run_name = "pretrain-2023-12-15"
save_dir = "E:/shared_resources/sequence_models/BEHRT/pretrain_behrt/pretrain-2023-12-15/"

[model_and_dataset]

[model_and_dataset.model]
@tasks = "behrt"

17 changes: 7 additions & 10 deletions psycop/projects/sequence_models/pretrain_mini_behrt.cfg
@@ -41,24 +41,21 @@ monitor = "val_loss"
save_top_k = 5
every_n_epochs = 1
mode = "min"
save_dir = ${training.trainer.logger.save_dir}
save_dir = ${logger.mlflow.save_dir}

[training.trainer.callbacks.*.learning_rate_monitor]
@callbacks = "learning_rate_monitor"
logging_interval = "epoch"


[training.trainer.logger]
[logger]
@utilities = "list_creator"

[logger.*.mlflow]
@loggers = "wandb"
name = null
experiment_name = "pretrain_mini_behrt"
run_name = "mini_behrt"
save_dir = "logs/"
version = null
offline = true
dir = null
id = null
anonymous = null
project = "psycop-sequence-models"
checkpoint_name = null

[model_and_dataset]
[model_and_dataset.model]