diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9ca940c62d4f2b..7c8efdff1a3fbf 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -434,7 +434,7 @@ def load(module: nn.Module, state_dict, prefix=""): for name, child in module._modules.items(): if child is not None: - load(child, prefix + name + ".") + load(child, state_dict, prefix + name + ".") load(model_to_load, state_dict, prefix=start_prefix) # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so