Skip to content

Commit

Permalink
Further fix the singular-leaf checkpointing, and add tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 463474218
  • Loading branch information
IvyZX authored and Flax Authors committed Jul 27, 2022
1 parent 2811ae5 commit bb8f222
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
7 changes: 4 additions & 3 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ def _checkpoint_path_step(path: str) -> Optional[float]:


def _split_gdas(
target: Dict[str, Any]) -> Tuple[Dict[str, Any], List[GlobalDeviceArray]]:
target: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[Tuple[GlobalDeviceArray, str]]]:
# When target is a single leaf instead of a pytree dict.
if not isinstance(target, (core.FrozenDict, dict)):
if isinstance(target, GlobalDeviceArray):
return GDA_PH, [target]
return GDA_PH, [(target, '')]
return target, []
# Traverse the target and handle GlobalDeviceArrays.
flattened = traverse_util.flatten_dict(target, keep_empty_nodes=True)
Expand Down Expand Up @@ -119,7 +120,7 @@ def _restore_gdas(state_dict,
# When target is a single leaf instead of a pytree dict.
if not isinstance(state_dict, (core.FrozenDict, dict)):
if isinstance(target, GlobalDeviceArray) and isinstance(
state_dict, GlobalDeviceArray):
state_dict, str) and state_dict.startswith(GDA_PH):
if not gda_manager:
raise errors.GDACheckpointingRequiredError(ckpt_path, step)
if not target:
Expand Down
11 changes: 11 additions & 0 deletions tests/checkpoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,17 @@ def test_save_restore_checkpoints_target_none(self):
expected_new_object = {str(k): v for k, v in enumerate(test_object1)}
jtu.check_eq(new_object, expected_new_object)

def test_save_restore_checkpoints_target_singular(self):
  """Saves and restores a bare array (a single leaf rather than a dict).

  Exercises both restore modes: once with `target=None` and once with an
  array passed as the restore target.
  """
  ckpt_dir = self.create_tempdir().full_path
  zeros = np.array([0, 0, 0], dtype=np.int32)
  ones = np.array([1, 1, 1], dtype=np.int32)

  # Step 0: save the all-ones array, then restore it without a target.
  checkpoints.save_checkpoint(ckpt_dir, ones, 0)
  restored = checkpoints.restore_checkpoint(ckpt_dir, target=None)
  jtu.check_eq(restored, ones)

  # Step 1: save the all-zeros array, then restore with the all-ones array
  # as target; the restored value is expected to be the step-1 contents.
  checkpoints.save_checkpoint(ckpt_dir, zeros, 1)
  restored = checkpoints.restore_checkpoint(ckpt_dir, target=ones)
  jtu.check_eq(restored, zeros)

def test_async_save_checkpoints(self):
tmp_dir = pathlib.Path(self.create_tempdir().full_path)
test_object0 = {'a': np.array([0, 0, 0], np.int32),
Expand Down

0 comments on commit bb8f222

Please sign in to comment.