-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_edlora.py
112 lines (86 loc) · 4.42 KB
/
test_edlora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import os
import os.path as osp
from re import T
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import DPMSolverMultistepScheduler
from diffusers.utils import check_min_version
from omegaconf import OmegaConf
from tqdm import tqdm
from lib.data.prompt_dataset import PromptDataset
from lib.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline
from lib.utils.convert_edlora_to_diffusers import convert_edlora
from lib.utils.util import NEGATIVE_PROMPT, compose_visualize, dict2str, pil_imwrite, set_path_logger
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version('0.18.2')
def visual_validation(accelerator, pipe, dataloader, current_iter, opt):
    """Sample an image for every validation prompt and write it to disk.

    Args:
        accelerator: Accelerator used to synchronise processes and to select
            the main process for the optional composition step.
        pipe: Pipeline exposing ``unet`` and ``text_encoder``; calling it must
            return an object with an ``images`` attribute.
        dataloader: Loader yielding batches with ``'prompts'``, ``'latents'``
            and ``'indices'`` entries; its dataset exposes ``opt['name']``.
        current_iter: Tag used as the output sub-directory name and as the
            suffix of every image file name.
        opt: Config dict. Reads ``opt['val']['sample']``,
            ``opt['path']['visualization']`` and
            ``opt['val']['compose_visualize']``.
    """
    dataset_name = dataloader.dataset.opt['name']

    # Sampling options do not change between batches, so look them up once.
    sample_cfg = opt['val']['sample']
    num_inference_steps = sample_cfg.get('num_inference_steps', 50)
    guidance_scale = sample_cfg.get('guidance_scale', 7.5)

    pipe.unet.eval()
    pipe.text_encoder.eval()

    # Directory of the most recently written image; stays None when the
    # dataloader yields nothing, so the composition step can be skipped
    # instead of failing on an unbound variable.
    last_save_dir = None

    for val_data in tqdm(dataloader):
        prompts = val_data['prompts']
        output = pipe(
            prompt=prompts,
            latents=val_data['latents'].to(dtype=torch.float16),
            negative_prompt=[NEGATIVE_PROMPT] * len(prompts),
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            no_cross_attention_kwargs=True,
        ).images

        for img, prompt, indice in zip(output, prompts, val_data['indices']):
            img_name = '{prompt}---G_{guidance_scale}_S_{steps}---{indice}'.format(
                prompt=prompt.replace(' ', '_'),
                guidance_scale=guidance_scale,
                steps=num_inference_steps,
                indice=indice)

            save_img_path = osp.join(opt['path']['visualization'], dataset_name, f'{current_iter}',
                                     f'{img_name}---{current_iter}.png')

            pil_imwrite(img, save_img_path)
            last_save_dir = osp.dirname(save_img_path)

        # tentative for out of GPU memory
        del output
        torch.cuda.empty_cache()

    # Let every process finish writing before the main process composes.
    accelerator.wait_for_everyone()

    if opt['val'].get('compose_visualize') and accelerator.is_main_process and last_save_dir is not None:
        compose_visualize(last_save_dir)
def test(root_path, args):
    """Run validation sampling once per LoRA alpha listed in the config.

    Args:
        root_path: Directory passed through to ``set_path_logger``.
        args: Parsed command-line namespace; ``args.opt`` is the path of the
            YAML config to load.
    """
    # Load the YAML config and resolve interpolations into plain containers.
    opt = OmegaConf.to_container(OmegaConf.load(args.opt), resolve=True)

    # Mixed precision is taken from the config file.
    accelerator = Accelerator(mixed_precision=opt['mixed_precision'])

    # Set up experiment paths / logging, main process first.
    with accelerator.main_process_first():
        set_path_logger(accelerator, root_path, args.opt, opt, is_train=False)

    logger = get_logger('mixofshow', log_level='INFO')
    logger.info(accelerator.state, main_process_only=True)
    logger.info(dict2str(opt))

    # Seed only when the config provides one.
    manual_seed = opt.get('manual_seed')
    if manual_seed is not None:
        set_seed(manual_seed)

    # Build the validation prompt dataset and its (unshuffled) loader.
    valset_cfg = opt['datasets']['val_vis']
    val_dataloader = torch.utils.data.DataLoader(
        PromptDataset(valset_cfg),
        batch_size=valset_cfg['batch_size_per_gpu'],
        shuffle=False,
    )

    enable_edlora = opt['models']['enable_edlora']
    # Neither of these depends on the alpha being evaluated.
    if enable_edlora:
        pipeclass, lora_type = EDLoRAPipeline, 'edlora'
    else:
        pipeclass, lora_type = StableDiffusionPipeline, 'lora'

    pretrained_path = opt['models']['pretrained_path']

    for lora_alpha in opt['val']['alpha_list']:
        # A fresh pipeline is loaded for every alpha before the LoRA weights are merged in.
        scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_path, subfolder='scheduler')
        pipe = pipeclass.from_pretrained(pretrained_path, scheduler=scheduler, torch_dtype=torch.float16)
        pipe = pipe.to('cuda')

        # NOTE(review): torch.load unpickles the checkpoint; only point lora_path at trusted files.
        lora_state = torch.load(opt['path']['lora_path'])
        pipe, new_concept_cfg = convert_edlora(pipe, lora_state, enable_edlora=enable_edlora, alpha=lora_alpha)
        pipe.set_new_concept_cfg(new_concept_cfg)

        # visualize embedding + LoRA weight shift
        logger.info(f'Start validation sample lora({lora_alpha}):')

        run_tag = f'validation_{lora_type}_{lora_alpha}'
        visual_validation(accelerator, pipe, val_dataloader, run_tag, opt)

        # Drop the reference so the pipeline can be freed before the next alpha is loaded.
        del pipe
if __name__ == '__main__':
    # Script entry point: the only option is the path of the YAML test config.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-opt',
        type=str,
        default='options/test/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml',
    )
    parsed_args = cli.parse_args()

    # Joining the script path with pardir and normalising via abspath yields
    # the directory that contains this file.
    script_dir = osp.abspath(osp.join(__file__, osp.pardir))
    test(script_dir, parsed_args)