From ade4ab1fbd4641d9c42d8a4551b7e6494ec7378d Mon Sep 17 00:00:00 2001
From: bipinkrishnan
Date: Sat, 11 May 2024 19:34:14 +0100
Subject: [PATCH 1/2] Add feature to save step when saving state

---
 src/accelerate/accelerator.py   | 4 +++-
 src/accelerate/checkpointing.py | 9 +++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index c109365b13a..b62ffedde27 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2952,6 +2952,7 @@ def _inner(folder):
             schedulers,
             dataloaders,
             self.state.process_index,
+            self.step,
             self.scaler,
             save_on_each_node=self.project_configuration.save_on_each_node,
             safe_serialization=safe_serialization,
@@ -3099,7 +3100,7 @@ def _inner(folder):
         else:
             map_location = "cpu"
 
-        load_accelerator_state(
+        override_attributes = load_accelerator_state(
             input_dir,
             models,
             optimizers,
@@ -3110,6 +3111,7 @@ def _inner(folder):
             map_location,
             **load_model_func_kwargs,
         )
+        self.step = override_attributes["step"]
         custom_checkpoints = [
             f for f in os.listdir(input_dir) if re.search(r"^custom_checkpoint_\d+\.pkl$", f) is not None
         ]
diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index d7688199041..89e81955aaa 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -55,6 +55,7 @@ def save_accelerator_state(
     schedulers: list,
     dataloaders: list,
     process_index: int,
+    step: int,
     scaler: GradScaler = None,
     save_on_each_node: bool = False,
     safe_serialization: bool = True,
@@ -82,6 +83,8 @@
             A list of dataloader instances to save their sampler states
         process_index (`int`):
             The current process index in the Accelerator state
+        step (`int`):
+            The current step in the internal step tracker
         scaler (`torch.cuda.amp.GradScaler`, *optional*):
             An optional gradient scaler instance to save
         save_on_each_node (`bool`, *optional*):
@@ -134,6 +137,7 @@
     # Random number generator states
     states = {}
     states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
+    states["step"] = step
     states["random_state"] = random.getstate()
     states["numpy_random_seed"] = np.random.get_state()
     states["torch_manual_seed"] = torch.get_rng_state()
@@ -181,6 +185,8 @@ def load_accelerator_state(
         load_model_func_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the model's `load_state_dict` method.
""" + # stores the `Accelerator` attributes to override + override_attributes = dict() if map_location not in [None, "cpu", "on_device"]: raise TypeError( "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`" @@ -240,6 +246,7 @@ def load_accelerator_state( # Random states try: states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl")) + override_attributes["step"] = states["step"] random.setstate(states["random_state"]) np.random.set_state(states["numpy_random_seed"]) torch.set_rng_state(states["torch_manual_seed"]) @@ -253,6 +260,8 @@ def load_accelerator_state( except Exception: logger.info("Could not load random states") + return override_attributes + def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False): """ From ff08093b6db8b1491308b8a59481f52b27587e31 Mon Sep 17 00:00:00 2001 From: bipinkrishnan Date: Thu, 13 Jun 2024 12:57:49 +0100 Subject: [PATCH 2/2] Update docstring for `load_accelerate_state` --- src/accelerate/checkpointing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py index 89e81955aaa..b707e34fdd5 100644 --- a/src/accelerate/checkpointing.py +++ b/src/accelerate/checkpointing.py @@ -184,6 +184,9 @@ def load_accelerator_state( What device to load the optimizer state onto. Should be one of either "cpu" or "on_device". load_model_func_kwargs (`dict`, *optional*): Additional arguments that can be passed to the model's `load_state_dict` method. + + Returns: + `dict`: Contains the `Accelerator` attributes to override while loading the state. """ # stores the `Accelerator` attributes to override override_attributes = dict()