From eef8bb2b1b6f0875ab0581079e1511d51654910e Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Fri, 27 Sep 2024 20:34:19 -0700 Subject: [PATCH] chore: fixing type hint checkpointing class (#590) This PR fixes the type hint of some classes in the checkpointing code. --- torchtitan/checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index b71419c6..266f689c 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -84,7 +84,7 @@ class ModelWrapper(Stateful): def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None: self.model = [model] if isinstance(model, nn.Module) else model - def state_dict(self) -> None: + def state_dict(self) -> Dict[str, Any]: return { k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items() } @@ -107,7 +107,7 @@ def __init__( self.model = [model] if isinstance(model, nn.Module) else model self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim - def state_dict(self) -> None: + def state_dict(self) -> Dict[str, Any]: func = functools.partial( get_optimizer_state_dict, options=StateDictOptions(flatten_optimizer_state_dict=True),