From e1ad633b76d6a983dc5aa3eb8e93640dee2afeb1 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Wed, 3 Jan 2024 11:57:26 -0800 Subject: [PATCH] Use mmap option to load_state_dict --- src/transformers/modeling_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 05d74d654252..6b428d8fc475 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -515,8 +515,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): map_location = "meta" else: map_location = "cpu" - - return torch.load(checkpoint_file, map_location=map_location, weights_only=True) + extra_args = {} + if version.parse(torch.__version__) >= version.parse("2.1.0"): + extra_args = {'mmap':True} + return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args) except Exception as e: try: with open(checkpoint_file) as f: