
Commit accd3bd
A change in order to pass black check
ydshieh committed Aug 31, 2022
1 parent 2842ded commit accd3bd
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/modeling_utils.py (4 changes: 2 additions & 2 deletions)
@@ -417,7 +417,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
-    def load(module: nn.Module, prefix=""):
+    def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         if is_deepspeed_zero3_enabled():
@@ -436,7 +436,7 @@ def load(module: nn.Module, prefix=""):
             if child is not None:
                 load(child, prefix + name + ".")
 
-    load(model_to_load, prefix=start_prefix)
+    load(model_to_load, state_dict, prefix=start_prefix)
     # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
     # it's safe to delete it. We don't call `gc.collect` here, instead let GC make its own decision.
     # See https://github.com/huggingface/transformers/issues/18782
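
For context, here is a minimal runnable sketch of the pattern this diff touches. It is an illustration, not the exact transformers implementation; the helper name `load_state_dict_recursively` and the simplified error handling are assumptions. The point of the change is that `state_dict` is passed through the recursion explicitly rather than captured from the enclosing scope, and the local copy is deleted afterwards so it can be garbage-collected earlier.

from torch import nn


def load_state_dict_recursively(model_to_load: nn.Module, state_dict: dict, start_prefix: str = ""):
    # Work on a copy so deleting it below cannot affect the caller's dict,
    # mirroring the "state_dict is a copy of the argument" comment in the diff.
    state_dict = state_dict.copy()
    error_msgs = []

    def load(module: nn.Module, state_dict, prefix=""):
        # PyTorch's `_load_from_state_dict` only loads the module's own parameters
        # and buffers, so every child module has to be visited explicitly.
        module._load_from_state_dict(state_dict, prefix, {}, True, [], [], error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Drop the reference so the copied dict can be collected by GC earlier.
    del state_dict
    return error_msgs


# Example usage: load one Linear layer's weights into another of the same shape.
source = nn.Linear(4, 2)
target = nn.Linear(4, 2)
errors = load_state_dict_recursively(target, source.state_dict())
assert not errors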
