Commit b8ba952

code

SKDDJ committed Aug 29, 2023
1 parent ab8d649 commit b8ba952
Showing 5 changed files with 71 additions and 39 deletions.
6 changes: 3 additions & 3 deletions .gitignore
@@ -4,14 +4,14 @@ models/
logs/
log*/
*__pycache__
model_output/
model_output/*
outputs/
model_output_test/

output_test/
real_reg/
real_reg/*

train_data/
train_data/*


## weights and lib
24 changes: 11 additions & 13 deletions libs/schedule.py
@@ -89,30 +89,28 @@ def LSimple_T2I(img, clip_img, text, data_type, nnet, schedule, device, config,
img_n, clip_img_n = xn
n = n.to(device)
clip_img_n=clip_img_n.to(torch.float32)
t_text=torch.zeros_like(n, device=device)
data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type
img_out, clip_img_out, text_out = nnet(img_n, clip_img_n, text, t_img=n, t_text=t_text, data_type=data_type)

print(clip_img_n.dtype)
print(img_n.dtype)

img_out, clip_img_out, text_out = nnet(img_n, clip_img_n, text, t_img=n, t_text=torch.zeros_like(n, device=device), data_type=data_type)



img_out, img_out_prior = torch.chunk(img_out, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

mask = torch.chunk(mask, 2, dim=0)[0]
# Compute instance loss
aloss = F.mse_loss(img_out.float(), target.float(), reduction="none")
aloss = ((aloss*mask).sum([1,2,3]) / mask.sum([1,2,3])).mean()

aloss = F.mse_loss(img_out.float(), target.float(), reduction="mean")

# aloss = ((aloss*mask).sum([1,2,3]) / mask.sum([1,2,3])).mean()

# Compute prior loss
prior_loss = F.mse_loss(img_out_prior.float(), target_prior.float(), reduction="mean")



bloss = aloss + config.prior_loss_weight * prior_loss


return bloss, 0.
bloss = aloss + config.prior_loss_weight * prior_loss

return bloss



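A minimal sketch (not part of this commit) of the loss LSimple_T2I now computes: the batch is split into an instance half and a prior-preservation half, the instance loss is a plain mean MSE (the masked variant is left commented out above), and the prior term is weighted by config.prior_loss_weight.

import torch
import torch.nn.functional as F

def lsimple_t2i_loss_sketch(img_out, target, prior_loss_weight):
    # First half of the batch: instance samples; second half: prior-class samples.
    img_out, img_out_prior = torch.chunk(img_out, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    # Instance loss: plain mean MSE over the instance half.
    aloss = F.mse_loss(img_out.float(), target.float(), reduction="mean")
    # Prior-preservation loss over the regularization half.
    prior_loss = F.mse_loss(img_out_prior.float(), target_prior.float(), reduction="mean")
    return aloss + prior_loss_weight * prior_loss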
17 changes: 15 additions & 2 deletions libs/uvit_multi_post_ln_v1.py
@@ -5,6 +5,19 @@
import einops
import torch.utils.checkpoint
import torch.nn.functional as F

"""
In uvit_multi_post_ln_v1.py, lora_cross_attention_itot and lora_cross_attention_ttoi are two different classes.
Both are cross-attention modules used to implement LoRA (Low-Rank Adaptation), but their inputs and outputs differ slightly.
lora_cross_attention_itot takes a tensor of shape (B, N, C), where B is the batch size, N the sequence length, and C the feature dimension.
It also receives a second tensor of shape (B, M, C), where M is the length of the other sequence. Its output is a tensor of shape (B, N, C),
the updated representation for every position of the first sequence.
lora_cross_attention_ttoi differs slightly: its inputs are a tensor of shape (B, M, C) and a tensor of shape (B, N, C),
where M and N are the lengths of the two sequences, and its output is a tensor of shape (B, M, C),
the updated representation for every position of the second sequence. (A shape sketch follows this file's diff.)
"""
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=24, network_alpha=None, device='cuda:0', dtype=None):
super().__init__()
@@ -522,7 +535,7 @@ def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,

self.token_embedding = nn.Embedding(2, embed_dim)
self.pos_embed_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
def add_lora(self,Lora):
def delete_lora(self,Lora):
self.Lora = Lora

def _init_weights(self, m):
@@ -578,7 +591,7 @@ def forward(self, img, clip_img, text, t_img, t_text, data_type):
skips = []
count = 0
for blk in self.in_blocks:
if hasattr(self, 'Lora'):
if not hasattr(self, 'Lora'):
t_img_token, t_text_token, token_embed, text, clip_img, img = x.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1)
modelttoi = self.lora_adapters_ttoi[count]
modelitot = self.lora_adapters_itot[count]
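A minimal shape sketch (not the repository's code) of the (B, N, C) / (B, M, C) contract described in the docstring above; the adapter internals here are assumed for illustration, only the input/output shapes mirror lora_cross_attention_itot / lora_cross_attention_ttoi.

import torch
import torch.nn as nn

class LoRACrossAttentionSketch(nn.Module):
    def __init__(self, dim, rank=24):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        # Low-rank (LoRA) branch added to the query projection.
        self.lora_down = nn.Linear(dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, dim, bias=False)

    def forward(self, x, context):
        # x: (B, N, C) query sequence; context: (B, M, C) key/value sequence.
        q = self.to_q(x) + self.lora_up(self.lora_down(x))
        k, v = self.to_k(context), self.to_v(context)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)  # (B, N, M)
        return attn @ v  # (B, N, C): one updated vector per query position

x = torch.randn(2, 77, 64)      # e.g. text tokens (B, N, C)
ctx = torch.randn(2, 257, 64)   # e.g. image tokens (B, M, C)
print(LoRACrossAttentionSketch(64)(x, ctx).shape)  # torch.Size([2, 77, 64])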
2 changes: 1 addition & 1 deletion sample_bench.py
@@ -357,7 +357,7 @@ def main(argv=None):
# init models
nnet = UViT(**config.nnet)
Lora = True
nnet.add_lora(Lora)
nnet.delete_lora(Lora)
print(config.nnet_path)
print(f'load nnet from {config.nnet_path}')
nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'), False)
61 changes: 41 additions & 20 deletions train.py
@@ -6,6 +6,15 @@
- Other parameters are to be set by the contestant
Code output:
- The fine-tuned model and any additional sub-modules
accelerate launch train.py \
--instance_data_dir="path to the target-image dataset" \
--outdir="your model output path" \
--class_data_dir "path to your regularization dataset" \
--with_prior_preservation --prior_loss_weight=1.0 \
--class_prompt="girl" --num_class_images=200 \
--instance_prompt="photo of a <new1> girl" \
--modifier_token "<new1>"
"""
from accelerate import Accelerator
import hashlib
@@ -140,9 +149,10 @@ def train(config):


caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)

nnet, optimizer = accelerator.prepare(train_state.nnet, train_state.optimizer)
nnet.to(device)

for name,param in nnet.named_parameters():
if name.split('.')[-1] not in ['lora_adapters_ttoi'] or ['lora_adapters_itot'] : # do not compute gradients for non-LoRA parts
param.requires_grad=False
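Note that the check name.split('.')[-1] not in ['lora_adapters_ttoi'] or ['lora_adapters_itot'] is always truthy, because the trailing or ['lora_adapters_itot'] is a non-empty list; a minimal sketch (an assumption, not the commit's code) of a condition that leaves only the two LoRA adapter groups trainable:

lora_keys = ("lora_adapters_ttoi", "lora_adapters_itot")
for name, param in nnet.named_parameters():
    # Train a parameter only if one of the LoRA adapter names appears in its dotted path.
    param.requires_grad = any(key in name.split('.') for key in lora_keys)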
@@ -281,7 +291,7 @@ def train_step():
original_shape = img4clip.shape

new_n = original_shape[0]

# img4clip comes out with the wrong shape; reshaping it here is the only workaround, the root cause is still unclear
# reshape the tensor to the new shape [new_n/3, 3, 224, 224]
new_shape = (new_n // 3, 3, *original_shape[1:])
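A minimal sketch (assumed shapes, not the commit's code) of the reshape workaround described above: img4clip arrives flattened along the batch dimension and is regrouped into the (B, 3, 224, 224) layout the CLIP encoder expects.

import torch

img4clip = torch.randn(12, 224, 224)              # assumed flattened shape (3*B, 224, 224)
new_n = img4clip.shape[0]
new_shape = (new_n // 3, 3, *img4clip.shape[1:])  # -> (B, 3, 224, 224)
img4clip = img4clip.reshape(new_shape)
print(img4clip.shape)                             # torch.Size([4, 3, 224, 224])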

@@ -295,27 +305,28 @@

bloss = LSimple_T2I(img=z,
clip_img=clip_img, text=text, data_type=data_type, nnet=nnet, schedule=schedule, device=device, config=config,mask=mask)

bloss.requires_grad = True

accelerator.backward(bloss)
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
if config.modifier_token is not None:
if accelerator.num_processes > 1:
grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
else:
grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
for i in range(len(modifier_token_id[1:])):
index_grads_to_zero = index_grads_to_zero & (
torch.arange(len(tokenizer)) != modifier_token_id[i]
)
grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
index_grads_to_zero, :
].fill_(0)
# if config.modifier_token is not None:
# if accelerator.num_processes > 1:
# grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
# else:
# grads_text_encoder = text_encoder.get_input_embeddings().weight.grad

# # Get the index for tokens that we want to zero the grads for
# index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
# for i in range(len(modifier_token_id[1:])):
# index_grads_to_zero = index_grads_to_zero & (
# torch.arange(len(tokenizer)) != modifier_token_id[i]
# )
# grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
# index_grads_to_zero, :
# ].fill_(0)


## Swapped this to nnet here; does that still count as fine-tuning?
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(text_encoder.parameters(), nnet.parameters())
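The commented-out block earlier in this hunk zeroed the token-embedding gradients for every token except the newly added modifier token(s); a minimal sketch of that technique (an assumption, reusing the tokenizer, text_encoder, and modifier_token_id names from the surrounding code, to run right after accelerator.backward):

# Keep only the modifier-token rows of the embedding gradient; zero the rest.
grads = text_encoder.get_input_embeddings().weight.grad
keep = torch.zeros(len(tokenizer), dtype=torch.bool)
keep[modifier_token_id] = True   # modifier_token_id: list of ids for the new token(s)
grads.data[~keep, :] = 0.0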
@@ -341,7 +352,7 @@ def train_step():
metrics['bloss'] = accelerator.gather(bloss.detach().mean()).mean().item()
# metrics['loss_img'] = accelerator.gather(loss_img.detach().mean()).mean().item()
# metrics['loss_clip_img'] = accelerator.gather(loss_clip_img.detach().mean()).mean().item()
metrics['scale'] = accelerator.scaler.get_scale()
# metrics['scale'] = accelerator.scaler.get_scale()
metrics['lr'] = train_state.optimizer.param_groups[0]['lr']
return metrics

@@ -624,4 +635,14 @@ def main():
if __name__ == "__main__":
main()



"""
accelerate launch train.py \
--instance_data_dir="/home/schengwei/competition/train_data/girl2" \
--outdir="/home/schengwei/competition/model_output/girl2"\
--class_data_dir "/home/schengwei/competition/real_reg/samples_person" \
--with_prior_preservation --prior_loss_weight=1.0 \
--class_prompt="girl" --num_class_images=200 \
--instance_prompt="photo of a <new1> girl" \
--modifier_token "<new1>"
"""
