Use less RAM when creating models #11958

Merged 1 commit on Jul 29, 2023.
1 change: 1 addition & 0 deletions modules/cmd_args.py
@@ -67,6 +67,7 @@
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
106 changes: 101 additions & 5 deletions modules/sd_disable_initialization.py
@@ -3,8 +3,31 @@
import torch
import transformers.utils.hub

from modules import shared

class DisableInitialization:

class ReplaceHelper:
def __init__(self):
self.replaced = []

def replace(self, obj, field, func):
original = getattr(obj, field, None)
if original is None:
return None

self.replaced.append((obj, field, original))
setattr(obj, field, func)

return original

def restore(self):
for obj, field, original in self.replaced:
setattr(obj, field, original)

self.replaced.clear()
Review comment from a Collaborator on lines +9 to +27: Could be implemented using contextlib.ExitStack too.
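As a minimal sketch of that suggestion (hypothetical, not part of this PR), the same replace/restore bookkeeping could be built on contextlib.ExitStack, which records an undo callback for each patch and unwinds them all in reverse order on close:

```python
import contextlib


class ExitStackReplaceHelper:
    """Hypothetical ExitStack-based variant of ReplaceHelper (not in this PR)."""

    def __init__(self):
        self._stack = contextlib.ExitStack()

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        # Register the undo step first, then apply the monkey-patch.
        self._stack.callback(setattr, obj, field, original)
        setattr(obj, field, func)
        return original

    def restore(self):
        # Invokes every registered setattr callback, newest first.
        self._stack.close()
```

The merged code keeps an explicit list of (obj, field, original) tuples instead, which achieves the same effect with no extra machinery.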



class DisableInitialization(ReplaceHelper):
"""
When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working
@@ -21,7 +44,7 @@ class DisableInitialization:
"""

def __init__(self, disable_clip=True):
self.replaced = []
super().__init__()
self.disable_clip = disable_clip

def replace(self, obj, field, func):
@@ -86,8 +109,81 @@ def transformers_configuration_utils_cached_file(url, *args, local_files_only=Fa
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced:
setattr(obj, field, original)
self.restore()

self.replaced.clear()

class InitializeOnMeta(ReplaceHelper):
"""
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
which results in those parameters having no values and taking no memory. model.to() will be broken and
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.

Usage:
```
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
```
"""

def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return

def set_device(x):
x["device"] = "meta"
return x

linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)

def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()


class LoadStateDictOnMeta(ReplaceHelper):
"""
Context manager that allows reading parameters from state_dict into a model that has some of its parameters on the meta device.
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
Meant to be used together with InitializeOnMeta above.

Usage:
```
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
model.load_state_dict(state_dict, strict=False)
```
"""

def __init__(self, state_dict, device):
super().__init__()
self.state_dict = state_dict
self.device = device

def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return

sd = self.state_dict
device = self.device

def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]

for name, param in params:
if param.is_meta:
self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)

original(self, state_dict, prefix, *args, **kwargs)

for name, _ in params:
key = prefix + name
if key in sd:
del sd[key]

linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))

def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
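
For context on the mechanism both context managers rely on (illustrative only, not part of the diff): a module constructed with device="meta" gets parameters that carry shape and dtype but no backing storage, so building the full model costs almost no RAM until real tensors are materialized from the checkpoint:

```python
import torch

# Parameters on the meta device have metadata but no storage.
layer = torch.nn.Linear(1024, 1024, device="meta")
assert layer.weight.is_meta
print(layer.weight.shape)  # torch.Size([1024, 1024]), no memory allocated

# LoadStateDictOnMeta materializes a real tensor in the parameter's place
# before the checkpoint values are copied in:
layer.weight = torch.nn.parameter.Parameter(
    torch.zeros_like(layer.weight, device="cpu"),
    requires_grad=layer.weight.requires_grad,
)
assert not layer.weight.is_meta
```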
16 changes: 10 additions & 6 deletions modules/sd_models.py
@@ -460,7 +460,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])



def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
sd_model = instantiate_from_config(sd_config.model)
except Exception:
pass
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)

except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True)

if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model)

with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)

sd_model.used_config = checkpoint_config

timer.record("create model")

load_model_weights(sd_model, checkpoint_info, state_dict, timer)
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)

if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
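Condensed from the hunks above (the DisableInitialization wrapper, error handling, and lowvram branch omitted), the new load path pairs the two context managers like so:

```python
# Sketch of the new flow in load_model(), condensed from this diff:
with sd_disable_initialization.InitializeOnMeta():
    # Parameters land on the meta device: no RAM is spent yet.
    sd_model = instantiate_from_config(sd_config.model)

with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
    # Each layer is materialized on the CPU as its weights are read, and
    # consumed entries are deleted from state_dict to free memory.
    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
```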
4 changes: 2 additions & 2 deletions webui.py
@@ -320,9 +320,9 @@ def load_model():
if modules.sd_hijack.current_optimizer is None:
modules.sd_hijack.apply_optimizations()

Thread(target=load_model).start()
devices.first_time_calculation()

Thread(target=devices.first_time_calculation).start()
Thread(target=load_model).start()

shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")