Merge branch 'nl_fix' of https://github.com/apax-hub/apax into nl_fix
M-R-Schaefer committed Apr 5, 2024
2 parents 46a8a65 + 6622207 commit b8eaec7
Showing 9 changed files with 1,469 additions and 511 deletions.
1 change: 0 additions & 1 deletion apax/cli/templates/train_config_full.yaml
@@ -78,4 +78,3 @@ checkpoints:
 
 progress_bar:
   disable_epoch_pbar: false
-  disable_nl_pbar: false
15 changes: 15 additions & 0 deletions apax/config/common.py
@@ -1,5 +1,6 @@
 import logging
 import os
+from collections.abc import MutableMapping
 from typing import Union
 
 import yaml
@@ -28,3 +29,17 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
         config = MDConfig.model_validate(config)
 
     return config
+
+
+def flatten(dictionary, parent_key="", separator="_"):
+    """https://stackoverflow.com/questions/6027558/
+    flatten-nested-dictionaries-compressing-keys
+    """
+    items = []
+    for key, value in dictionary.items():
+        new_key = parent_key + separator + key if parent_key else key
+        if isinstance(value, MutableMapping):
+            items.extend(flatten(value, new_key, separator=separator).items())
+        else:
+            items.append((new_key, value))
+    return dict(items)
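As a quick illustration of this helper (the values below are hypothetical), nested keys are joined with the separator:

    nested = {"data": {"directory": "models", "experiment": "run1"}, "seed": 1}
    print(flatten(nested))
    # {'data_directory': 'models', 'data_experiment': 'run1', 'seed': 1}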
49 changes: 41 additions & 8 deletions apax/config/train_config.py
@@ -1,18 +1,20 @@
 import logging
 import os
 from pathlib import Path
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 import yaml
 from pydantic import (
     BaseModel,
     ConfigDict,
     Field,
     NonNegativeFloat,
     PositiveFloat,
     PositiveInt,
+    create_model,
     model_validator,
 )
+from typing_extensions import Annotated
 
 from apax.data.statistics import scale_method_list, shift_method_list

@@ -235,16 +237,47 @@ class LossConfig(BaseModel, extra="forbid"):
     parameters: dict = {}
 
 
-class CallbackConfig(BaseModel, frozen=True, extra="forbid"):
+class CSVCallback(BaseModel, frozen=True, extra="forbid"):
     """
-    Configuration of the training callbacks.
+    Configuration of the CSVCallback.
     Parameters
     ----------
-    name: Keyword of the callback used. Currently we implement "csv" and "tensorboard".
+    name: Keyword of the callback used.
     """
 
-    name: str
+    name: Literal["csv"]
+
+
+class TBCallback(BaseModel, frozen=True, extra="forbid"):
+    """
+    Configuration of the TensorBoard callback.
+    Parameters
+    ----------
+    name: Keyword of the callback used.
+    """
+
+    name: Literal["tensorboard"]
+
+
+class MLFlowCallback(BaseModel, frozen=True, extra="forbid"):
+    """
+    Configuration of the MLFlow callback.
+    Parameters
+    ----------
+    name: Keyword of the callback used.
+    experiment: Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>
+    """
+
+    name: Literal["mlflow"]
+    experiment: str
+
+
+CallBack = Annotated[
+    Union[CSVCallback, TBCallback, MLFlowCallback], Field(discriminator="name")
+]
 
 
 class TrainProgressbarConfig(BaseModel, extra="forbid"):
@@ -254,11 +287,11 @@ class TrainProgressbarConfig(BaseModel, extra="forbid"):
     Parameters
     ----------
     disable_epoch_pbar: Set to True to disable the epoch progress bar.
-    disable_nl_pbar: Set to True to disable the NL precomputation progress bar.
+    disable_batch_pbar: Set to True to disable the batch progress bar.
     """
 
     disable_epoch_pbar: bool = False
-    disable_nl_pbar: bool = False
+    disable_batch_pbar: bool = True
 
 
 class CheckpointConfig(BaseModel, extra="forbid"):
@@ -311,7 +344,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
     metrics: List[MetricsConfig] = []
     loss: List[LossConfig]
     optimizer: OptimizerConfig = OptimizerConfig()
-    callbacks: List[CallbackConfig] = [CallbackConfig(name="csv")]
+    callbacks: List[CallBack] = [CSVCallback(name="csv")]
     progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
     checkpoints: CheckpointConfig = CheckpointConfig()

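Since validation of the CallBack union dispatches on the discriminator field name, a config entry selects its callback model by that keyword alone. A minimal sketch, assuming pydantic v2 (which the model_validate calls above imply); the experiment path is made up:

    from pydantic import TypeAdapter

    from apax.config.train_config import CallBack

    cb = TypeAdapter(CallBack).validate_python(
        {"name": "mlflow", "experiment": "/Users/me/my_experiment"}
    )
    print(type(cb).__name__)  # MLFlowCallback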
53 changes: 33 additions & 20 deletions apax/train/callbacks.py
@@ -1,12 +1,17 @@
 import logging
+from pathlib import Path
 
 import tensorflow as tf
 from keras.callbacks import CSVLogger, TensorBoard
 
+from apax.config.common import flatten
+from apax.config.train_config import Config
+
 log = logging.getLogger(__name__)
 
 
-def initialize_callbacks(callback_configs, model_version_path):
+def initialize_callbacks(config: Config, model_version_path: Path):
+    callback_configs = config.callbacks
     log.info("Initializing Callbacks")
 
     dummy_model = tf.keras.Model()
@@ -26,33 +31,41 @@ def initialize_callbacks(callback_configs, model_version_path):
             "model": dummy_model,
         },
     }
 
-    callback_configs = [config.name for config in callback_configs]
-    if "csv" in callback_configs and "tensorboard" in callback_configs:
-        csv_idx, tb_idx = callback_configs.index("csv"), callback_configs.index(
-            "tensorboard"
-        )
+    names = [conf.name for conf in callback_configs]
+    if "csv" in names and "tensorboard" in names:
         msg = (
             "Using both csv and tensorboard callbacks is not supported at the moment."
-            " Only the first of the two will be used."
+            " Rerun training with only one of the two."
         )
-        print("Warning: " + msg)
-        log.warning(msg)
-        if csv_idx < tb_idx:
-            callback_configs.pop(tb_idx)
-        else:
-            callback_configs.pop(csv_idx)
+        raise ValueError(msg)
 
     callbacks = []
     for callback_config in callback_configs:
-        callback_info = callback_dict[callback_config]
+        if callback_config.name == "mlflow":
+            try:
+                import mlflow
+                from mlflow.tensorflow import MLflowCallback
+            except ImportError:
+                log.warning("Make sure MLFlow is installed correctly")
+            mlflow.login()
+            mlflow.tensorflow.autolog()
+            experiment = callback_config.experiment
+            mlflow.set_experiment(experiment)
+
+            params = config.model_dump()
+            params = flatten(params)
+            mlflow.log_params(params)
+            callback = MLflowCallback()
+            callback.set_model(dummy_model)
+        else:
+            callback_info = callback_dict[callback_config.name]
 
-        path_arg_name = callback_info["path_arg_name"]
-        path = {path_arg_name: callback_info["log_path"]}
+            path_arg_name = callback_info["path_arg_name"]
+            path = {path_arg_name: callback_info["log_path"]}
 
-        kwargs = callback_info["kwargs"]
-        callback = callback_info["class"](**path, **kwargs)
-        callback.set_model(callback_info["model"])
+            kwargs = callback_info["kwargs"]
+            callback = callback_info["class"](**path, **kwargs)
+            callback.set_model(callback_info["model"])
         callbacks.append(callback)
 
     return tf.keras.callbacks.CallbackList([callback])
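Because the mlflow branch logs the flattened config via mlflow.log_params, initialize_callbacks now receives the whole Config rather than just the callback list. A minimal sketch of the new call; the file and directory paths are hypothetical:

    from pathlib import Path

    from apax.config.common import parse_config
    from apax.train.callbacks import initialize_callbacks

    config = parse_config("configs/train.yaml")  # hypothetical config file
    callbacks = initialize_callbacks(config, Path("models/run1"))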
6 changes: 4 additions & 2 deletions apax/train/run.py
@@ -94,12 +94,12 @@ def run(user_config, log_level="error"):
     seed_py_np_tf(config.seed)
     rng_key = jax.random.PRNGKey(config.seed)
 
     log.info("Initializing directories")
     config.data.model_version_path.mkdir(parents=True, exist_ok=True)
     setup_logging(config.data.model_version_path / "train.log", log_level)
     config.dump_config(config.data.model_version_path)
     log.info(f"Running on {jax.devices()}")
 
-    callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path)
+    callbacks = initialize_callbacks(config, config.data.model_version_path)
     loss_fn = initialize_loss_fn(config.loss)
     Metrics = initialize_metrics(config.metrics)

@@ -148,5 +148,7 @@ def run(user_config, log_level="error"):
         sam_rho=config.optimizer.sam_rho,
         patience=config.patience,
         disable_pbar=config.progress_bar.disable_epoch_pbar,
+        disable_batch_pbar=config.progress_bar.disable_batch_pbar,
+        is_ensemble=config.n_models > 1,
     )
     log.info("Finished training")
24 changes: 24 additions & 0 deletions apax/train/trainer.py
@@ -29,6 +29,7 @@ def fit(
     sam_rho=0.0,
     patience: Optional[int] = None,
     disable_pbar: bool = False,
+    disable_batch_pbar: bool = True,
     is_ensemble=False,
 ):
     log.info("Beginning Training")
@@ -70,6 +71,16 @@ def fit(
         epoch_loss.update({"train_loss": 0.0})
         train_batch_metrics = Metrics.empty()
 
+        batch_pbar = trange(
+            0,
+            train_steps_per_epoch,
+            desc="Batches",
+            ncols=100,
+            mininterval=1.0,
+            disable=disable_batch_pbar,
+            leave=False,
+        )
+
         for batch_idx in range(train_steps_per_epoch):
             callbacks.on_train_batch_begin(batch=batch_idx)

@@ -84,6 +95,8 @@

epoch_loss["train_loss"] += jnp.mean(batch_loss)
callbacks.on_train_batch_end(batch=batch_idx)
batch_pbar.update()

epoch_loss["train_loss"] /= train_steps_per_epoch
epoch_loss["train_loss"] = float(epoch_loss["train_loss"])

@@ -95,13 +108,24 @@
         if val_ds is not None:
             epoch_loss.update({"val_loss": 0.0})
             val_batch_metrics = Metrics.empty()
+
+            batch_pbar = trange(
+                0,
+                val_steps_per_epoch,
+                desc="Batches",
+                ncols=100,
+                mininterval=1.0,
+                disable=disable_batch_pbar,
+                leave=False,
+            )
             for batch_idx in range(val_steps_per_epoch):
                 batch = next(batch_val_ds)
 
                 batch_loss, val_batch_metrics = val_step(
                     state.params, batch, val_batch_metrics
                 )
                 epoch_loss["val_loss"] += batch_loss
+                batch_pbar.update()
 
             epoch_loss["val_loss"] /= val_steps_per_epoch
             epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
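For reference, a standalone sketch of the tqdm settings used for the batch bars: the bar is transient (leave=False), redraws at most once per second (mininterval=1.0), and becomes a no-op when disable=True. The loop body stands in for a training step:

    from time import sleep

    from tqdm import trange

    for _ in trange(0, 50, desc="Batches", ncols=100, mininterval=1.0,
                    disable=False, leave=False):
        sleep(0.01)  # stand-in for one batch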