Skip to content

Commit

Permalink
Fix snapshotter ordering (#2327)
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner authored May 20, 2022
1 parent 6461a07 commit 3492f44
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/garage/experiment/snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def load(self, load_dir, itr='last'):
raise FileNotFoundError(errno.ENOENT,
os.strerror(errno.ENOENT),
'*.pkl file in', load_dir)
files.sort()
files.sort(key=_extract_snapshot_itr)
load_from_file = files[0] if itr == 'first' else files[-1]
load_from_file = os.path.join(load_dir, load_from_file)

Expand All @@ -170,5 +170,20 @@ def load(self, load_dir, itr='last'):
return cloudpickle.load(file)


def _extract_snapshot_itr(filename: str) -> int:
"""Extracts the integer itr from a filename.
Args:
filename(str): The snapshot filename.
Returns:
int: The snapshot as an integer.
"""
base = os.path.splitext(filename)[0]
digits = base.split('itr_')[1]
return int(digits)


class NotAFileError(Exception):
"""Raise when the snapshot is not a file."""
8 changes: 8 additions & 0 deletions tests/garage/experiment/test_snapshotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class TestSnapshotter:

def setup_method(self):
# pylint: disable=consider-using-with
self.temp_dir = tempfile.TemporaryDirectory()

def teardown_method(self):
Expand Down Expand Up @@ -78,3 +79,10 @@ def test_conflicting_params(self):
Snapshotter(snapshot_dir=self.temp_dir.name,
snapshot_mode='gap_overwrite',
snapshot_gap=1)

def test_sorts_correctly(self):
snapshotter = Snapshotter(self.temp_dir.name, 'all', 2)
snapshotter.save_itr_params(80, {'test_itr': 80})
snapshotter.save_itr_params(120, {'test_itr': 120})
last = snapshotter.load(self.temp_dir.name)
assert last['test_itr'] == 120

0 comments on commit 3492f44

Please sign in to comment.