Skip to content

Commit

Permalink
send tensors to the correct device when loading from safetensors file…
Browse files Browse the repository at this point in the history
… with memmap disabled for AUTOMATIC1111#11260
  • Loading branch information
AUTOMATIC1111 authored and lsjspl committed Jul 11, 2023
1 parent df37879 commit cc98fb1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,13 @@ def read_metadata_from_safetensors(filename):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

if not shared.opts.disable_mmap_load_safetensors:
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

Expand Down
2 changes: 1 addition & 1 deletion modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def list_samplers():
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files (fixes very slow loading speed in some cases)."),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))

options_templates.update(options_section(('training', "Training"), {
Expand Down

0 comments on commit cc98fb1

Please sign in to comment.