Skip to content

Commit

Permalink
update flax and ipsuite tests
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Dec 14, 2024
1 parent 7071728 commit a5e382b
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 504 deletions.
5 changes: 3 additions & 2 deletions apax/train/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ def create_params(model, rng_key, sample_input: tuple, n_models: int):


def load_state(state, ckpt_dir):
ckpt_dir = Path(ckpt_dir)
start_epoch = 0
target = {"model": state, "epoch": 0}
checkpoints_exist = Path(ckpt_dir).is_dir()
checkpoints_exist = ckpt_dir.is_dir()
if checkpoints_exist:
log.info("Loading checkpoint")
raw_restored = checkpoints.restore_checkpoint(ckpt_dir, target=target, step=None)
raw_restored = checkpoints.restore_checkpoint(ckpt_dir.resolve(), target=target, step=None)
state = raw_restored["model"]
start_epoch = raw_restored["epoch"] + 1
log.info("Successfully restored checkpoint from epoch %d", raw_restored["epoch"])
Expand Down
Loading

0 comments on commit a5e382b

Please sign in to comment.