Skip to content

Commit

Permalink
optim: save torch tensors as numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Dec 2, 2024
1 parent bb18db7 commit 363c94d
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions neps/state/filebased.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ class ReaderWriterSeedSnapshot:

PY_RNG_TUPLE_FILENAME: ClassVar = "py_rng.npy"
NP_RNG_STATE_FILENAME: ClassVar = "np_rng_state.npy"
TORCH_RNG_STATE_FILENAME: ClassVar = "torch_rng_state.pt"
TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.pt"
TORCH_RNG_STATE_FILENAME: ClassVar = "torch_rng_state.npy"
TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.npy"
SEED_INFO_FILENAME: ClassVar = "seed_info.json"

@classmethod
Expand Down Expand Up @@ -159,7 +159,9 @@ def read(cls, directory: Path) -> SeedSnapshot:
import torch

if torch_rng_path_exists:
torch_rng_state = torch.load(torch_rng_path, weights_only=True)
# OPTIM: This ends up being much faster to go to numpy
_bytes = np.fromfile(torch_rng_path, dtype=np.uint8)
torch_rng_state = torch.tensor(_bytes, dtype=torch.uint8)

if torch_cuda_rng_path_exists:
# By specifying `weights_only=True`, it disables arbitrary object loading
Expand Down Expand Up @@ -211,7 +213,8 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None:
if snapshot.torch_rng is not None:
import torch

torch.save(snapshot.torch_rng, torch_rng_path)
# OPTIM: This ends up being much faster to go to numpy
snapshot.torch_rng.numpy().tofile(torch_rng_path)

if snapshot.torch_cuda_rng is not None:
import torch
Expand Down

0 comments on commit 363c94d

Please sign in to comment.