Skip to content

Commit

Permalink
remove checkpoint loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Damowerko committed Mar 11, 2024
1 parent 47d3cf9 commit 267040f
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def train(trainer: pl.Trainer, params: argparse.Namespace):
train_loader = DataLoader(train_dataset, shuffle=True, **dataloader_kwargs)
val_loader = DataLoader(val_dataset, **dataloader_kwargs)

trainer.fit(model, train_loader, val_loader, ckpt_path=get_checkpoint_path())
trainer.fit(model, train_loader, val_loader)


def test(trainer: pl.Trainer, params: argparse.Namespace):
Expand All @@ -129,7 +129,7 @@ def test(trainer: pl.Trainer, params: argparse.Namespace):
)
trainer = make_trainer(params)
model = Conv2dCoder(**vars(params))
trainer.test(model, test_loader, ckpt_path=get_checkpoint_path())
trainer.test(model, test_loader)


def study(params: argparse.Namespace):
Expand Down Expand Up @@ -248,11 +248,6 @@ def make_trainer(params: argparse.Namespace, callbacks=[]) -> pl.Trainer:
)


def get_checkpoint_path() -> Union[str, None]:
ckpt_path = "./checkpoints/best.ckpt"
return ckpt_path if os.path.exists(ckpt_path) else None


def get_online_dataset(params: argparse.Namespace, n_experiments=1, n_steps=100):
return OnlineImageDataset(
**dict(
Expand Down

0 comments on commit 267040f

Please sign in to comment.