update sample config
SKDDJ committed Aug 31, 2023
1 parent 01f4a38 commit c35ba0b
Showing 7 changed files with 87 additions and 53 deletions.
3 changes: 3 additions & 0 deletions configs/sample_config.py
@@ -76,6 +76,9 @@ def get_config():


# sample
config.mode = "t2i"
config.n_samples = 9 # controls the number of generated images
config.n_iter = 1 # too many iterations may lead to overfitting or samples too close to the training data
config.sample = d(
sample_steps=30,
scale=7.,
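For reference, n_samples and n_iter combine multiplicatively: each of the n_iter rounds draws a fresh batch of n_samples images. A minimal sketch of that loop, with sample_fn as a hypothetical stand-in for the repo's actual sampler:

# Sketch of how n_samples / n_iter are typically consumed at sampling time.
# sample_fn is a hypothetical placeholder, not a function from this repo.
def generate(prompt, sample_fn, n_samples=9, n_iter=1):
    images = []
    for _ in range(n_iter):  # independent sampling rounds
        images.extend(sample_fn(prompt, batch_size=n_samples))
    return images  # len(images) == n_iter * n_samples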
13 changes: 7 additions & 6 deletions configs/unidiffuserv1.py
@@ -23,7 +23,7 @@ def get_config():
config.save_interval = 300
# config.max_step = 400

config.max_step = 250
config.max_step = 1000
config.batch_size = 1

config.center_crop = True
Expand All @@ -36,7 +36,7 @@ def get_config():

config.dataloader_num_workers = 10

config.sample_batch_size = 4
# config.sample_batch_size = 4  # it seems useless...
config.revision = None
config.num_workers = 10
# config.batch_size = 6
@@ -95,14 +95,15 @@ def get_config():


# sample

config.mode = "t2i"
config.n_samples = 2
config.n_iter = 6
config.n_samples = 9 # controls the number of generated images
config.n_iter = 1 # too many iterations may lead to overfitting or samples too close to the training data
config.nrow = 4
config.sample = d(
sample_steps=30,
sample_steps=100, # I bumped this from 30 to 100
scale=7.,
t2i_cfg_mode='true_uncond'
t2i_cfg_mode='true_uncond' # hilarious: true_uncond mode was used before, and it would be a miracle if those generated images were viewable
)

return config
3 changes: 3 additions & 0 deletions libs/clip.py
@@ -206,6 +206,7 @@ def load_textual_inversion(
subfolder=subfolder,
user_agent=user_agent,
)
print("开始加载safetensor")
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except Exception as e:
if not allow_pickle:
@@ -279,13 +280,15 @@ def load_textual_inversion(

# add tokens and get ids
self.tokenizer.add_tokens(tokens)

token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
token_ids_and_embeddings += zip(token_ids, embeddings)

# logger.info(f"Loaded textual inversion embedding for {token}.")

# resize token embeddings and set all new embeddings
self.text_encoder.resize_token_embeddings(len(self.tokenizer))

for token_id, embedding in token_ids_and_embeddings:
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

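The hunk above follows the standard textual-inversion loading recipe: register the placeholder token, resize the embedding matrix, then overwrite the new row with the learned vector. A condensed, self-contained sketch of that flow with Hugging Face transformers (the checkpoint name and learned_embeds.bin path are illustrative assumptions):

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Assumed file format: a dict mapping placeholder tokens (e.g. "<new1>") to vectors.
state_dict = torch.load("learned_embeds.bin", map_location="cpu")
for token, embedding in state_dict.items():
    tokenizer.add_tokens(token)  # register the placeholder token
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.resize_token_embeddings(len(tokenizer))  # grow the embedding matrix
    text_encoder.get_input_embeddings().weight.data[token_id] = embedding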
13 changes: 9 additions & 4 deletions sample.py
@@ -118,6 +118,10 @@ def t2i_nnet(x, timesteps, text): # text is the low dimension version of the text clip embedding
config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
config.sample.t2i_cfg_mode == 'true_uncond': using the unconditional model learned by our method
3. return linear combination of conditional output and unconditional output
'empty_token' mode: use the original cfg with an empty string; the generated images are not constrained by the associated text, so the results are freer and more diverse.
'true_uncond' mode: use the unconditional model learned by our method; the generated images do not depend on the associated text, so the results are more unconditional and independent.
"""
z, clip_img = split(x)

@@ -276,11 +280,12 @@ def main(argv=None):
set_seed(42)
config = get_config()
args = get_args()

# config.n_iter = 6
# config.n_samples = 9

config.output_path = args.output_path
config.nnet_path = os.path.join(args.restore_path, "final.ckpt",'nnet.pth')
# config.nnet_path = args.restore_path
config.n_samples = 3
config.n_iter = 1
device = "cuda"

# init models
@@ -290,7 +295,7 @@
autoencoder = libs.autoencoder.get_model(**config.autoencoder)
clip_text_model = FrozenCLIPEmbedder(version=config.clip_text_model, device=device)
clip_text_model.to(device)
clip_text_model.load_textual_inversion(args.weight_dir, weight_name="<new1>.bin")
clip_text_model.load_textual_inversion(args.weight_dir, token = "<new1>" , weight_name="<new1>.bin")

nnet_mapping_dict = {}
autoencoder_mapping_dict = {}
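The "linear combination" mentioned in the docstring is classifier-free guidance; the two t2i_cfg_mode settings differ only in how the unconditional branch is produced (an empty-string embedding for 'empty_token', a learned unconditional model for 'true_uncond'). A minimal sketch of the combination step, assuming x_cond and x_uncond are the two network outputs:

import torch

def cfg_combine(x_cond: torch.Tensor, x_uncond: torch.Tensor, scale: float = 7.0) -> torch.Tensor:
    # Classifier-free guidance: move the unconditional prediction
    # toward the conditional one, amplified by `scale`.
    return x_uncond + scale * (x_cond - x_uncond)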
1 change: 1 addition & 0 deletions sample.sh
@@ -2,6 +2,7 @@
python sample.py --restore_path model_output/boy1 --prompt_path eval_prompts/boy1.json --output_path outputs/boy1
python sample.py --restore_path model_output/boy2 --prompt_path eval_prompts/boy2.json --output_path outputs/boy2
python sample.py --restore_path model_output/girl1 --prompt_path eval_prompts/girl1.json --output_path outputs/girl1

python sample.py --restore_path model_output/girl2 --prompt_path eval_prompts/girl2.json --output_path outputs/girl2 --weight_dir model_output/girl2


94 changes: 57 additions & 37 deletions train.py
@@ -142,8 +142,19 @@ def train(config):
concepts_list[i] = concept
accelerator.wait_for_everyone()


train_state = utils.initialize_train_state(config, device, uvit_class=UViT)
pretrained_model_name_or_path = "/home/schengwei/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce"
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
revision = None,
use_fast=False,
)
text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path , config.revision)
text_encoder = text_encoder_cls.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision
)
text_encoder.to(device)
train_state = utils.initialize_train_state(config, device, uvit_class=UViT,text_encoder = text_encoder)
logging.info(f'load nnet from {config.nnet_path}')
train_state.nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'), False)

@@ -165,19 +176,8 @@
# Modify the code of custom diffusion to directly import the clip text encoder
# instead of freezing all parameters.
# clip_text_model = CLIPEmbedder(version=config.clip_text_model, device=device)
pretrained_model_name_or_path = "/home/schengwei/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce"
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
revision = None,
use_fast=False,
)
text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path , config.revision)
text_encoder = text_encoder_cls.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision
)
text_encoder.to(device)
clip_text_model = CLIPEmbedder(version=config.clip_text_model, device=device)

# clip_text_model = CLIPEmbedder(version=config.clip_text_model, device=device)
clip_img_model, clip_img_model_preprocess = clip.load(config.clip_img_model, jit=False)
clip_img_model.to(device).eval().requires_grad_(False)

@@ -209,7 +209,7 @@ def train(config):
)

# Convert the initializer_token, placeholder_token to ids
token_ids = clip_text_model.tokenizer.encode([initializer_token], add_special_tokens=False)
token_ids = tokenizer.encode([initializer_token], add_special_tokens=False)

#[42170]
#ktn
@@ -220,7 +220,7 @@

initializer_token_id.append(token_ids[0])
modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token))

print("modifier_token_id",modifier_token_id)


# Resize the token embeddings as we are adding new special tokens to the tokenizer
@@ -265,7 +265,6 @@ def train(config):
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=config.dataloader_num_workers,
pin_memory=True,
drop_last=True
)

@@ -280,7 +279,26 @@
_betas = stable_diffusion_beta_schedule()
schedule = Schedule(_betas)
logging.info(f'use {schedule}')

for name, param in nnet.named_parameters():
param.requires_grad = True
for name, param in nnet.named_parameters():
if 'lora_adapters_itot' not in name and 'lora_adapters_ttoi' not in name:
param.requires_grad = False
for name, param in nnet.named_parameters():
if 'text_embed' in name or 'token_embedding' in name:
param.requires_grad = True

# check which parameters are left unfrozen
for name, param in text_encoder.named_parameters():
if param.requires_grad:
print(f"unfrozen parameter: {name}")

total_unfrozen_params = sum(p.numel() for p in text_encoder.parameters() if p.requires_grad)
print("number of unfrozen parameters:", total_unfrozen_params)
# 77560320 lora_adapter+text_embedding 37946112 token_embedding
# INFO - nnet has 1029970000 parameters
# INFO - text_encoder has 123060480 parameters

def train_step():
metrics = dict()

Expand All @@ -307,7 +325,7 @@ def train_step():

bloss = LSimple_T2I(img=z,
clip_img=clip_img, text=text, data_type=data_type, nnet=nnet, schedule=schedule, device=device, config=config,mask=mask)
bloss.requires_grad = True
# bloss.requires_grad = True

accelerator.backward(bloss)
# Zero out the gradients for all token embeddings except the newly added
@@ -336,15 +354,16 @@ def train_step():
else nnet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
itertools.chain(text_encoder.get_input_embeddings().parameters(), nnet.parameters())
if args.modifier_token is not None
else nnet.parameters(),
)
for name, param in nnet.named_parameters():
if param.grad is None:
print(name)
for name, param in text_encoder.named_parameters():
if param.grad is None:
print(name)
exit()
# update parameters
optimizer.step()

lr_scheduler.step()

train_state.ema_update(config.get('ema_rate', 0.9999))
@@ -358,24 +377,25 @@ def train_step():
metrics['lr'] = train_state.optimizer.param_groups[0]['lr']
return metrics

@torch.no_grad()
@torch.autocast(device_type='cuda')
def eval(total_step):
"""
write evaluation code here
"""
# @torch.no_grad()
# @torch.autocast(device_type='cuda')
# def eval(total_step):
# """
# write evaluation code here
# """

return
# return

def loop():
log_step = 0
eval_step = 0
eval_step = 100000
save_step = config.save_interval

while True:
nnet.train()
with accelerator.accumulate(nnet):
metrics = train_step()
print("metrics",metrics)
accelerator.wait_for_everyone()

if accelerator.is_main_process:
Expand All @@ -397,7 +417,7 @@ def loop():



if total_step >= config.max_step:
if total_step >= 500:
logging.info(f"saving final ckpts to {config.outdir}...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir)
train_state.save(os.path.join(config.outdir, 'final.ckpt'))
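A collapsed hunk above carries the comment "Zero out the gradients for all token embeddings except the newly added" ones. In Custom Diffusion-style training this is usually a masked write on the embedding gradient after backward(); a sketch under that assumption, reusing text_encoder and modifier_token_id from the surrounding code:

import torch

grads = text_encoder.get_input_embeddings().weight.grad
mask = torch.ones(grads.shape[0], dtype=torch.bool, device=grads.device)
mask[modifier_token_id] = False  # spare only the newly added tokens
grads[mask] = 0.0  # freeze every pre-existing token embedding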
13 changes: 7 additions & 6 deletions utils.py
@@ -127,21 +127,22 @@ def to(self, device):
def cnt_params(model):
return sum(param.numel() for param in model.parameters())


def initialize_train_state(config, device, uvit_class):
clip_text_model = CLIPEmbedder(version=config.clip_text_model, device=device)
def initialize_train_state(config, device, uvit_class,text_encoder = None):

params = []
nnet = uvit_class(**config.nnet)
params = list(itertools.chain(clip_text_model.transformer.get_input_embeddings().parameters(), nnet.lora_adapters_itot.parameters(), nnet.lora_adapters_ttoi.parameters()))

params = list(itertools.chain(text_encoder.get_input_embeddings().parameters(), nnet.lora_adapters_itot.parameters(), nnet.lora_adapters_ttoi.parameters()))
nnet_ema = uvit_class(**config.nnet)
nnet_ema.eval()
logging.info(f'nnet has {cnt_params(nnet)} parameters')

logging.info(f'text_encoder has {cnt_params(text_encoder)} parameters')

optimizer = get_optimizer(params, **config.optimizer)
lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)

train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
nnet=nnet, nnet_ema=nnet_ema, text_embedding=clip_text_model.transformer.get_input_embeddings())
nnet=nnet, nnet_ema=nnet_ema, text_embedding=text_encoder.get_input_embeddings())
train_state.ema_update(0)
train_state.to(device)
# no need to resume
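With this change, initialize_train_state no longer constructs its own CLIPEmbedder; the caller injects the text encoder, as train.py now does. A usage sketch, where the Hub checkpoint name is an assumption standing in for the local snapshot path hard-coded in train.py:

from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder"
)
text_encoder.to(device)
train_state = initialize_train_state(config, device, uvit_class=UViT,
                                     text_encoder=text_encoder)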
