diff --git a/neps/state/filebased.py b/neps/state/filebased.py index f6324860..eb9efe21 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -130,8 +130,8 @@ class ReaderWriterSeedSnapshot: PY_RNG_TUPLE_FILENAME: ClassVar = "py_rng.npy" NP_RNG_STATE_FILENAME: ClassVar = "np_rng_state.npy" - TORCH_RNG_STATE_FILENAME: ClassVar = "torch_rng_state.pt" - TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.pt" + TORCH_RNG_STATE_FILENAME: ClassVar = "torch_rng_state.npy" + TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.npy" SEED_INFO_FILENAME: ClassVar = "seed_info.json" @classmethod @@ -159,7 +159,9 @@ def read(cls, directory: Path) -> SeedSnapshot: import torch if torch_rng_path_exists: - torch_rng_state = torch.load(torch_rng_path, weights_only=True) + # OPTIM: This ends up being much faster to go to numpy + _bytes = np.fromfile(torch_rng_path, dtype=np.uint8) + torch_rng_state = torch.tensor(_bytes, dtype=torch.uint8) if torch_cuda_rng_path_exists: # By specifying `weights_only=True`, it disables arbitrary object loading @@ -211,7 +213,8 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: if snapshot.torch_rng is not None: import torch - torch.save(snapshot.torch_rng, torch_rng_path) + # OPTIM: This ends up being much faster to go to numpy + snapshot.torch_rng.numpy().tofile(torch_rng_path) if snapshot.torch_cuda_rng is not None: import torch