Device support improvements (MPS) #1054

Merged: 3 commits, Jan 31, 2024

Changes from 1 commit
Refactor memory cleaning into a single function
akx committed Jan 23, 2024

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature (the key has expired).
commit afc38707d57a055577627ee4e17ade4581ed0140
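This commit consolidates the memory-cleaning boilerplate repeated across the training and generation scripts (a CUDA-guarded torch.cuda.empty_cache() plus, at most call sites, gc.collect()) into a single clean_memory() helper in the new library/device_utils.py, giving follow-up backend work, such as the MPS support this PR targets, one place to hook into.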
6 changes: 2 additions & 4 deletions fine_tune.py
@@ -2,7 +2,6 @@
 # XXX dropped option: hypernetwork training
 
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -11,6 +10,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -158,9 +158,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()

7 changes: 3 additions & 4 deletions gen_img_diffusers.py
@@ -66,6 +66,7 @@
 import numpy as np
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -888,8 +889,7 @@ def __call__(
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample(generator=generator)
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             init_latents = []
             for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                 init_latent_dist = self.vae.encode(
@@ -1047,8 +1047,7 @@ def __call__(
         if vae_batch_size >= batch_size:
             image = self.vae.decode(latents).sample
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             images = []
             for i in tqdm(range(0, batch_size, vae_batch_size)):
                 images.append(
9 changes: 9 additions & 0 deletions library/device_utils.py
@@ -0,0 +1,9 @@
+import gc
+
+import torch
+
+
+def clean_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
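device_utils.clean_memory() is now the one place where backend-specific cache clearing lives. Given the PR title, a natural follow-up is an MPS branch; the sketch below illustrates what that could look like. It is not code from this commit, and it assumes a PyTorch build that exposes torch.mps.empty_cache() (added around PyTorch 2.0):

import gc

import torch


def clean_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        # Hypothetical MPS branch, not part of this commit: release cached
        # Metal memory on Apple Silicon. The try/except guards older
        # PyTorch builds that lack torch.mps.empty_cache().
        try:
            torch.mps.empty_cache()
        except AttributeError:
            pass
    gc.collect()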
5 changes: 2 additions & 3 deletions library/sdxl_train_util.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import math
 import os
 from typing import Optional
@@ -8,6 +7,7 @@
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
+from library.device_utils import clean_memory
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
 
 TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
         unet.to(accelerator.device)
         vae.to(accelerator.device)
 
-    gc.collect()
-    torch.cuda.empty_cache()
+    clean_memory()
     accelerator.wait_for_everyone()
 
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
10 changes: 4 additions & 6 deletions library/train_util.py
@@ -20,7 +20,6 @@
     Union,
 )
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
-import gc
 import glob
 import math
 import os
@@ -67,6 +66,7 @@
 
 # from library.attention_processors import FlashAttnProcessor
 # from library.hypernetwork import replace_attentions_for_hypernetwork
+from library.device_utils import clean_memory
 from library.original_unet import UNet2DConditionModel
 
 # Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
@@ -2278,8 +2278,7 @@ def cache_batch_latents(
             info.latents_flipped = flipped_latent
 
     # FIXME this slows down caching a lot, specify this as an option
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clean_memory()
 
 
 def cache_batch_text_encoder_outputs(
@@ -4006,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
         unet.to(accelerator.device)
         vae.to(accelerator.device)
 
-    gc.collect()
-    torch.cuda.empty_cache()
+    clean_memory()
     accelerator.wait_for_everyone()
 
     return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4816,7 +4814,7 @@ def sample_images_common(
 
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
+    clean_memory()
 
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
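Note: via the shared helper, the call sites in cache_batch_latents and sample_images_common, which previously only emptied the CUDA cache, now also run gc.collect().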
10 changes: 4 additions & 6 deletions sdxl_gen_img.py
@@ -18,6 +18,7 @@
 import numpy as np
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -640,8 +641,7 @@ def __call__(
             init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
             init_latents = init_latent_dist.sample(generator=generator)
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             init_latents = []
             for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                 init_latent_dist = self.vae.encode(
@@ -780,8 +780,7 @@ def __call__(
         if vae_batch_size >= batch_size:
             image = self.vae.decode(latents.to(self.vae.dtype)).sample
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             images = []
             for i in tqdm(range(0, batch_size, vae_batch_size)):
                 images.append(
@@ -796,8 +795,7 @@ def __call__(
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
 
         if output_type == "pil":
             # image = self.numpy_to_pil(image)
9 changes: 3 additions & 6 deletions sdxl_train.py
@@ -1,7 +1,6 @@
 # training with captions
 
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -11,6 +10,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -252,9 +252,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
 
@@ -407,8 +405,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)
9 changes: 3 additions & 6 deletions sdxl_train_control_net_lllite.py
@@ -2,7 +2,6 @@
 # training code for ControlNet-LLLite with passing cond_image to U-Net's forward
 
 import argparse
-import gc
 import json
 import math
 import os
@@ -15,6 +14,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -164,9 +164,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
 
@@ -291,8 +289,7 @@ def train(args):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)
9 changes: 3 additions & 6 deletions sdxl_train_control_net_lllite_old.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import json
 import math
 import os
@@ -12,6 +11,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -163,9 +163,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
 
@@ -264,8 +262,7 @@ def train(args):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)
7 changes: 3 additions & 4 deletions sdxl_train_network.py
@@ -1,6 +1,7 @@
 import argparse
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -65,8 +66,7 @@ def cache_text_encoder_outputs_if_needed(
             org_unet_device = unet.device
             vae.to("cpu")
             unet.to("cpu")
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
 
             # When the TE is not being trained, it will not be prepared, so we need to use explicit autocast
             with accelerator.autocast():
@@ -81,8 +81,7 @@ def cache_text_encoder_outputs_if_needed(
 
             text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
 
             if not args.lowram:
                 print("move vae and unet back to original device")
6 changes: 2 additions & 4 deletions train_controlnet.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import json
 import math
 import os
@@ -12,6 +11,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -219,9 +219,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()

6 changes: 2 additions & 4 deletions train_db.py
@@ -1,7 +1,6 @@
 # DreamBooth training
 # XXX dropped option: fine_tune
 
-import gc
 import argparse
 import itertools
 import math
@@ -12,6 +11,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -138,9 +138,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()

6 changes: 2 additions & 4 deletions train_network.py
@@ -1,6 +1,5 @@
 import importlib
 import argparse
-import gc
 import math
 import os
 import sys
@@ -14,6 +13,7 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -266,9 +266,7 @@ def train(self, args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()

6 changes: 2 additions & 4 deletions train_textual_inversion.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -8,6 +7,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -363,9 +363,7 @@ def train(self, args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
        accelerator.wait_for_everyone()

6 changes: 2 additions & 4 deletions train_textual_inversion_XTI.py
@@ -1,6 +1,5 @@
 import importlib
 import argparse
-import gc
 import math
 import os
 import toml
@@ -9,6 +8,7 @@
 from tqdm import tqdm
 import torch
 
+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex
 
 init_ipex()
@@ -286,9 +286,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()