Skip to content

Commit

Permalink
Newest code for scoring and sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
SKDDJ committed Sep 26, 2023
1 parent 7d32d4e commit 05648e5
Show file tree
Hide file tree
Showing 15 changed files with 49 additions and 47 deletions.
Empty file modified .dockerignore
100644 → 100755
Empty file.
Empty file modified .vscode/launch.json
100644 → 100755
Empty file.
Empty file modified Dockerfile
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion configs/unidiffuserv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_config():

config.lr_scheduler = d(
name='customized',
warmup_steps=0
warmup_steps=20
)

# config.lr_scheduler = d(
Expand Down
Empty file modified eval_prompts/girl2_edit.json
100644 → 100755
Empty file.
4 changes: 2 additions & 2 deletions libs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def __init__(self,
]
for i in range(len(inst_img_path)):
path, text = inst_img_path[i]
if str(path).endswith('.jpg'):
inst_img_path[i] = (path, 'photo of a <new1> girl')
# if str(path).endswith('.jepg'):
# inst_img_path[i] = (path, 'photo of a <new1> girl')

self.instance_images_path.extend(inst_img_path)

Expand Down
8 changes: 4 additions & 4 deletions libs/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def tilde_beta(self, s, t):

def sample(self, x0): # sample from q(xn|x0), where n is uniform
if isinstance(x0, list):
# n = np.random.choice(list(range(1, self.N + 1)), (len(x0[0]),))
n = np.array([1000, 1000,1000, 1000])
n = np.random.choice(list(range(1, self.N + 1)), (len(x0[0]),))
# n = np.array([1000, 1000,1000, 1000])
eps = [torch.randn_like(tensor) for tensor in x0]
xn = [stp(self.cum_alphas[n] ** 0.5, tensor) + stp(self.cum_betas[n] ** 0.5, _eps) for tensor, _eps in zip(x0, eps)]
return torch.tensor(n), eps, xn
Expand Down Expand Up @@ -93,8 +93,8 @@ def LSimple_T2I(img, clip_img, text, data_type, nnet, schedule, device, config,
clip_img_n=clip_img_n.to(torch.float32)
t_text=torch.zeros_like(n, device=device)
data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type
torch.save(clip_img_n, 'clip_img.pt')
torch.save(img_n, 'img.pt')
# torch.save(clip_img_n, 'clip_img.pt')
# torch.save(img_n, 'img.pt')

img_out, clip_img_out, text_out = nnet(img_n, clip_img_n, text, t_img=n, t_text=t_text, data_type=data_type)

Expand Down
27 changes: 13 additions & 14 deletions libs/uvit_multi_post_ln_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def __init__(self, img_dim=1024, rank=24, text_dim=77, heads=8, qkv_bias=False,
nn.init.zeros_(self.to_k.weight)
nn.init.zeros_(self.to_v.weight)
nn.init.zeros_(self.to_out.weight)
nn.init.constant_(self.to_q.bias, 0)
nn.init.constant_(self.to_k.bias, 0)
nn.init.constant_(self.to_v.bias, 0)
nn.init.constant_(self.to_out.bias, 0)
# nn.init.constant_(self.to_q.bias, 0)
# nn.init.constant_(self.to_k.bias, 0)
# nn.init.constant_(self.to_v.bias, 0)
# nn.init.constant_(self.to_out.bias, 0)
def head_to_batch_dim(self, tensor, out_dim=3):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
Expand Down Expand Up @@ -190,10 +190,10 @@ def __init__(self, img_dim=1024, rank=24, text_dim=77, heads=8, qkv_bias=False,
nn.init.zeros_(self.to_k.weight)
nn.init.zeros_(self.to_v.weight)
nn.init.zeros_(self.to_out.weight)
nn.init.constant_(self.to_q.bias, 0)
nn.init.constant_(self.to_k.bias, 0)
nn.init.constant_(self.to_v.bias, 0)
nn.init.constant_(self.to_out.bias, 0)
# nn.init.constant_(self.to_q.bias, 0)
# nn.init.constant_(self.to_k.bias, 0)
# nn.init.constant_(self.to_v.bias, 0)
# nn.init.constant_(self.to_out.bias, 0)
def head_to_batch_dim(self, tensor, out_dim=3):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
Expand Down Expand Up @@ -656,8 +656,7 @@ def forward(self, img, clip_img, text, t_img, t_text, data_type):

x = x + pos_embed
x = self.pos_drop(x)
t_img_token, t_text_token, token_embed, text, clip_img, img = x.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1)
init_text = text

skips = []
count = 0
for blk in self.in_blocks:
Expand All @@ -668,8 +667,8 @@ def forward(self, img, clip_img, text, t_img, t_text, data_type):
modelttoi.to('cuda')
modelitot.to('cuda')

lora_img = modelttoi(img,init_text)
lora_text = modelitot(img,init_text)
lora_img = modelttoi(img,text)
lora_text = modelitot(img,text)
x = torch.cat((t_img_token, t_text_token, token_embed, text, clip_img, img), dim=1)
x = blk(x, skip = None, lora_input_img = lora_img,lora_input_text = lora_text)
count += 1
Expand All @@ -691,8 +690,8 @@ def forward(self, img, clip_img, text, t_img, t_text, data_type):
modelitot = self.adapters_itot[count]
modelttoi.to('cuda')
modelitot.to('cuda')
lora_img = modelttoi(img,init_text)
lora_text = modelitot(img,init_text)
lora_img = modelttoi(img,text)
lora_text = modelitot(img,text)
del y
x = blk(x, skip, lora_input_img = lora_img,lora_input_text = lora_text)
count += 1
Expand Down
Empty file modified output.txt
100644 → 100755
Empty file.
4 changes: 3 additions & 1 deletion run.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ docker run -it --gpus "device=${device}" --rm -v /home/test01/eval_prompts_advan

# sudo docker run -it --gpus all --rm -v /home/wuyujia/competition/eval_prompts_advance:/workspace/eval_prompts_advance \
# -v /home/wuyujia/competition/train_data:/workspace/train_data -v /home/wuyujia/competition/models:/workspace/models \
# -v /home/wuyujia/competition/indocker_shell.sh:/workspace/indocker_shell.sh -v /home/wuyujia/.cache/huggingface:/root/.cache/huggingface xiugou:v1
# -v /home/wuyujia/competition/indocker_shell.sh:/workspace/indocker_shell.sh -v /home/wuyujia/.cache/huggingface:/root/.cache/huggingface xiugou:v1


4 changes: 2 additions & 2 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def main(argv=None):
clip_text_strandard = FrozenCLIPEmbedder(version=config.clip_text_model, device=device).to("cpu")
total_diff_parameters += compare_model(clip_text_strandard, clip_text_model, clip_text_mapping_dict)
del clip_text_strandard

clip_text_model.load_textual_inversion(args.restore_path, token = "<new1>" , weight_name="<new1>.bin")
clip_text_model.to(device)
autoencoder.to(device)
Expand All @@ -344,7 +344,7 @@ def main(argv=None):
else:
prompt = prompt.replace("girl", "<new1> girl")

config.prompt = prompt
config.prompt = prompt
print("sampling with prompt:", prompt)
sample(prompt_index, config, nnet, clip_text_model, autoencoder, device)

Expand Down
Empty file modified sample_test_sparsity.py
100644 → 100755
Empty file.
6 changes: 3 additions & 3 deletions score.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def score(dataset_base, prompts_base, outputs_base):
for taskname in SIM_TASKNAMES + EDIT_TASKNAMES:
task_prompt = os.path.join(prompts_base, f'{taskname}.json')
assert os.path.exists(task_prompt), f"Missing Prompt file: {task_prompt}"
task_output = os.path.join(outputs_base, f'{taskname}')
task_output = os.path.join(outputs_base, f'{taskname}_600')
assert os.path.exists(task_output), f"Missing Output folder: {task_output}"

def score_task(sample_folder, dataset_folder, prompt_json):
Expand Down Expand Up @@ -211,7 +211,7 @@ def score_task(sample_folder, dataset_folder, prompt_json):
for dataname, taskname in zip(DATANAMES, SIM_TASKNAMES):
task_dataset = os.path.join(dataset_base, f'{dataname}')
task_prompt = os.path.join(prompts_base, f'{taskname}.json')
task_output = os.path.join(outputs_base, f'{taskname}')
task_output = os.path.join(outputs_base, f'{taskname}_600')
score = score_task(task_output, task_dataset, task_prompt)
print(f"Score for task {taskname}: ", score)
sim_scores.append(score)
Expand All @@ -223,7 +223,7 @@ def score_task(sample_folder, dataset_folder, prompt_json):
task_dataset = os.path.join(dataset_base, f'{dataname}')
task_prompt = os.path.join(prompts_base, f'{taskname}.json')
print(taskname,"taskname")
task_output = os.path.join(outputs_base, f'{taskname}')
task_output = os.path.join(outputs_base, f'{taskname}_600')
score = score_task(task_output, task_dataset, task_prompt)
print(f"Score for task {taskname}: ", score)
edit_scores.append(score)
Expand Down
33 changes: 17 additions & 16 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,24 +465,25 @@ def loop():
# wandb.log(utils.add_prefix(metrics, 'train'), step=total_step)
# train_state.save(os.path.join(config.log_dir, f'{total_step:04}.ckpt'))
log_step += config.log_interval
if total_step == 1200:
logging.info(f"saving final ckpts to {config.outdir}1200...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + "_1200")
# if total_step == 800:
# logging.info(f"saving final ckpts to {config.outdir}_800...")
# save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + "_800")
# # train_state.save(os.path.join(config.outdir, 'final.ckpt'))
# train_state.save_lora(os.path.join(config.outdir , 'lora.pt.tmp'))

if total_step == 1000:
logging.info(f"saving final ckpts to {config.outdir}_1000...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + "_1000")
# train_state.save(os.path.join(config.outdir, 'final.ckpt'))
train_state.save_lora(os.path.join(config.outdir + "_1200", 'lora.pt.tmp'))

if total_step == 2400:
logging.info(f"saving final ckpts to {config.outdir}2400...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + "_2400")
# train_state.save(os.path.join(config.outdir, 'final.ckpt'))
train_state.save_lora(os.path.join(config.outdir + "_2400", 'lora.pt.tmp'))
train_state.save_lora(os.path.join(config.outdir + "_1000", 'lora.pt.tmp'))

if total_step == 3000:
logging.info(f"saving final ckpts to {config.outdir}3000...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + "_3000")
if total_step == 1500:
logging.info(f"saving final ckpts to {config.outdir}_1500...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir + "_1500")
# train_state.save(os.path.join(config.outdir, 'final.ckpt'))
train_state.save_lora(os.path.join(config.outdir + "_3000", 'lora.pt.tmp'))
break
train_state.save_lora(os.path.join(config.outdir + "_1500", 'lora.pt.tmp'))


# if total_step >= eval_step:
# eval(total_step)
# eval_step += config.eval_interval
Expand All @@ -492,7 +493,7 @@ def loop():
# train_state.save(os.path.join(config.ckpt_root, f'{total_step:04}.ckpt'))
# save_step += config.save_interval

if total_step >= 800:
if total_step >= 2000:
logging.info(f"saving final ckpts to {config.outdir}...")
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.outdir)
# train_state.save(os.path.join(config.outdir, 'final.ckpt'))
Expand Down
8 changes: 4 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import itertools
from libs.clip import CLIPEmbedder
from peft import inject_adapter_in_model, LoraConfig,get_peft_model
# lora_config = LoraConfig(
# inference_mode=False, r=128, lora_alpha=90, lora_dropout=0.1,target_modules=["qkv","fc1","fc2","proj","to_out","to_q","to_k","to_v","text_embed","clip_img_embed"]
# )
lora_config = LoraConfig(
inference_mode=False, r=128, lora_alpha=90, lora_dropout=0.1,target_modules=["qkv","fc1","fc2","proj","text_embed","clip_img_embed"]
inference_mode=False, r=128, lora_alpha=90, lora_dropout=0.1,target_modules=["qkv","fc1","fc2","proj","to_out","to_q","to_k","to_v","text_embed","clip_img_embed"]
)
# lora_config = LoraConfig(
# inference_mode=False, r=128, lora_alpha=90, lora_dropout=0.1,target_modules=["qkv","fc1","fc2","proj","text_embed","clip_img_embed"]
# )
# lora_config = LoraConfig(
# inference_mode=False, r=128, lora_alpha=64, lora_dropout=0.1,target_modules=["qkv","to_out","to_q","to_k","to_v","text_embed","clip_img_embed"]
# )#94,838,784
# lora_config = LoraConfig(
Expand Down

0 comments on commit 05648e5

Please sign in to comment.