Skip to content

Commit

Permalink
[Feature] keep_checkpoints_num and checkpoint_at_end (facebookres…
Browse files Browse the repository at this point in the history
…earch#102)

* amend

* amend

* amend
  • Loading branch information
matteobettini authored Jun 18, 2024
1 parent 2a337e8 commit a930915
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
5 changes: 5 additions & 0 deletions benchmarl/conf/experiment/base_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,8 @@ restore_file: null
# Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
# Set it to 0 to disable checkpointing
checkpoint_interval: 0
# Wether to checkpoint when the experiment is done
checkpoint_at_end: False
# How many checkpoints to keep. As new checkpoints are taken, temporally older checkpoints are deleted to keep this number of
# checkpoints. The checkpoint at the end is included in this number. Set to `null` to keep all checkpoints.
keep_checkpoints_num: 3
17 changes: 15 additions & 2 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import os
import time
from collections import OrderedDict
from collections import deque, OrderedDict
from dataclasses import dataclass, MISSING
from pathlib import Path
from typing import Dict, List, Optional
Expand Down Expand Up @@ -96,7 +96,9 @@ class ExperimentConfig:

save_folder: Optional[str] = MISSING
restore_file: Optional[str] = MISSING
checkpoint_interval: float = MISSING
checkpoint_interval: int = MISSING
checkpoint_at_end: bool = MISSING
keep_checkpoints_num: Optional[int] = MISSING

def train_batch_size(self, on_policy: bool) -> int:
"""
Expand Down Expand Up @@ -280,6 +282,8 @@ def validate(self, on_policy: bool):
f"checkpoint_interval ({self.checkpoint_interval}) "
f"is not a multiple of the collected_frames_per_batch ({self.collected_frames_per_batch(on_policy)})"
)
if self.keep_checkpoints_num is not None and self.keep_checkpoints_num <= 0:
raise ValueError("keep_checkpoints_num must be greater than zero or null")
if self.max_n_frames is None and self.max_n_iters is None:
raise ValueError("n_iters and total_frames are both not set")

Expand Down Expand Up @@ -483,6 +487,7 @@ def _setup_name(self):
self.model_name = self.model_config.associated_class().__name__.lower()
self.environment_name = self.task.env_name().lower()
self.task_name = self.task.name.lower()
self._checkpointed_files = deque([])

if self.config.restore_file is not None and self.config.save_folder is not None:
raise ValueError(
Expand Down Expand Up @@ -668,6 +673,8 @@ def _collection_loop(self):
pbar.update()
sampling_start = time.time()

if self.config.checkpoint_at_end:
self._save_experiment()
self.close()

def close(self):
Expand Down Expand Up @@ -835,10 +842,16 @@ def load_state_dict(self, state_dict: Dict) -> None:

def _save_experiment(self) -> None:
"""Checkpoint trainer"""
if self.config.keep_checkpoints_num is not None:
while len(self._checkpointed_files) >= self.config.keep_checkpoints_num:
file_to_delete = self._checkpointed_files.popleft()
file_to_delete.unlink(missing_ok=False)

checkpoint_folder = self.folder_name / "checkpoints"
checkpoint_folder.mkdir(parents=False, exist_ok=True)
checkpoint_file = checkpoint_folder / f"checkpoint_{self.total_frames}.pt"
torch.save(self.state_dict(), checkpoint_file)
self._checkpointed_files.append(checkpoint_file)

def _load_experiment(self) -> Experiment:
"""Load trainer from checkpoint"""
Expand Down

0 comments on commit a930915

Please sign in to comment.