diff --git a/.gitignore b/.gitignore
index 8901ae2..69bb724 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,14 +4,14 @@
 models/
 logs/
 log*/
 *__pycache__
-model_output/
+model_output/*
 outputs/
 model_output_test/
 output_test/
-real_reg/
+real_reg/*

-train_data/
+train_data/*

 ## weights anf lib
diff --git a/libs/schedule.py b/libs/schedule.py
index bd1446c..e054b43 100644
--- a/libs/schedule.py
+++ b/libs/schedule.py
@@ -89,30 +89,28 @@ def LSimple_T2I(img, clip_img, text, data_type, nnet, schedule, device, config,
     img_n, clip_img_n = xn
     n = n.to(device)
     clip_img_n=clip_img_n.to(torch.float32)
+    t_text=torch.zeros_like(n, device=device)
+    data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type
+    img_out, clip_img_out, text_out = nnet(img_n, clip_img_n, text, t_img=n, t_text=t_text, data_type=data_type)
-    print(clip_img_n.dtype)
-    print(img_n.dtype)
-
-    img_out, clip_img_out, text_out = nnet(img_n, clip_img_n, text, t_img=n, t_text=torch.zeros_like(n, device=device), data_type=data_type)
-
-
+
     img_out, img_out_prior = torch.chunk(img_out, 2, dim=0)
     target, target_prior = torch.chunk(target, 2, dim=0)
     mask = torch.chunk(mask, 2, dim=0)[0]

     # Compute instance loss
-    aloss = F.mse_loss(img_out.float(), target.float(), reduction="none")
-    aloss = ((aloss*mask).sum([1,2,3]) / mask.sum([1,2,3])).mean()
-
+    aloss = F.mse_loss(img_out.float(), target.float(), reduction="mean")
+
+    # aloss = ((aloss*mask).sum([1,2,3]) / mask.sum([1,2,3])).mean()
+
     # Compute prior loss
     prior_loss = F.mse_loss(img_out_prior.float(), target_prior.float(), reduction="mean")
-
-    bloss = aloss + config.prior_loss_weight * prior_loss
-
-    return bloss, 0.
+    bloss = aloss + config.prior_loss_weight * prior_loss
+
+    return bloss
diff --git a/libs/uvit_multi_post_ln_v1.py b/libs/uvit_multi_post_ln_v1.py
index 0618c62..ec214c4 100644
--- a/libs/uvit_multi_post_ln_v1.py
+++ b/libs/uvit_multi_post_ln_v1.py
@@ -5,6 +5,19 @@ import einops
 import torch.utils.checkpoint
 import torch.nn.functional as F

+
+"""
+In uvit_multi_post_ln_v1.py, lora_cross_attention_itot and lora_cross_attention_ttoi are two different classes.
+Both are cross-attention modules used to implement LoRA (Low-Rank Adaptation), but their inputs and outputs differ slightly.
+
+lora_cross_attention_itot takes a tensor of shape (B, N, C), where B is the batch size, N is the sequence length and C is the feature dimension.
+It also receives a tensor of shape (B, M, C), where M is the length of the other sequence. Its output is a tensor of shape (B, N, C),
+the updated representation of every position of the first sequence.
+
+lora_cross_attention_ttoi differs slightly from lora_cross_attention_itot in its inputs and outputs.
+Its inputs are a tensor of shape (B, M, C) and a tensor of shape (B, N, C), where M and N are the lengths of the two sequences.
+Its output is a tensor of shape (B, M, C), the updated representation of every position of the second sequence.
+"""
 class LoRALinearLayer(nn.Module):
     def __init__(self, in_features, out_features, rank=24, network_alpha=None, device='cuda:0', dtype=None):
         super().__init__()
@@ -522,7 +535,7 @@ def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,
         self.token_embedding = nn.Embedding(2, embed_dim)
         self.pos_embed_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

-    def add_lora(self,Lora):
+    def delete_lora(self,Lora):
         self.Lora = Lora

     def _init_weights(self, m):
@@ -578,7 +591,7 @@ def forward(self, img, clip_img, text, t_img, t_text, data_type):
         skips = []
         count = 0
         for blk in self.in_blocks:
-            if hasattr(self, 'Lora'):
+            if not hasattr(self, 'Lora'):
                 t_img_token, t_text_token, token_embed, text, clip_img, img = x.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1)
                 modelttoi = self.lora_adapters_ttoi[count]
                 modelitot = self.lora_adapters_itot[count]
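For orientation, the LoRA cross-attention adapters described in the docstring above are built from low-rank linear layers such as LoRALinearLayer (rank=24, optional network_alpha). The layer's body is not shown in this diff, so the following is only a minimal sketch of how such a low-rank adapter is commonly implemented, not the repository's actual code:

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # Hypothetical low-rank update: y = up(down(x)) * (alpha / rank).
    def __init__(self, in_features, out_features, rank=24, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)   # project down to `rank` dims
        self.up = nn.Linear(rank, out_features, bias=False)    # project back up
        self.scale = network_alpha / rank if network_alpha is not None else 1.0
        nn.init.normal_(self.down.weight, std=1.0 / rank)      # small random init for the down projection
        nn.init.zeros_(self.up.weight)                         # zero init so the adapter starts as a no-op

    def forward(self, x):
        # x: (B, N, in_features) -> (B, N, out_features); the result is added to the frozen branch's output
        return self.up(self.down(x)) * self.scale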
diff --git a/sample_bench.py b/sample_bench.py
index 370c8cc..322b0eb 100644
--- a/sample_bench.py
+++ b/sample_bench.py
@@ -357,7 +357,7 @@ def main(argv=None):
     # init models
     nnet = UViT(**config.nnet)
     Lora = True
-    nnet.add_lora(Lora)
+    nnet.delete_lora(Lora)
     print(config.nnet_path)
     print(f'load nnet from {config.nnet_path}')
     nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'), False)
diff --git a/train.py b/train.py
index a06a244..d46ed16 100644
--- a/train.py
+++ b/train.py
@@ -6,6 +6,15 @@
 - Other parameters are to be set by the contestants themselves
 Code output:
 - The fine-tuned model plus any additional sub-modules
+
+accelerate launch train.py \
+  --instance_data_dir="path to the target image dataset" \
+  --outdir="your model output path"\
+  --class_data_dir "your regularization dataset path" \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --class_prompt="girl" --num_class_images=200 \
+  --instance_prompt="photo of a girl" \
+  --modifier_token ""
 """
 from accelerate import Accelerator
 import hashlib
@@ -140,9 +149,10 @@ def train(config):

     caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)

-    
+
     nnet, optimizer = accelerator.prepare(train_state.nnet, train_state.optimizer)
     nnet.to(device)
+
     for name,param in nnet.named_parameters():
         if name.split('.')[-1] not in ['lora_adapters_ttoi'] or ['lora_adapters_itot'] :  # do not compute gradients for the non-LoRA parts
             param.requires_grad=False
@@ -281,7 +291,7 @@ def train_step():

         original_shape = img4clip.shape
         new_n = original_shape[0]
-        
+        # the shape of img4clip comes out wrong when passed along, so it has to be reshaped here; the exact cause is still unclear
         # change the tensor to the new shape [new_n/3, 3, 224, 224]
         new_shape = (new_n // 3, 3, *original_shape[1:])
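A note on the parameter-freezing loop in the @@ -140 hunk above: because of operator precedence, `name.split('.')[-1] not in ['lora_adapters_ttoi'] or ['lora_adapters_itot']` is parsed as `(... not in [...]) or ['lora_adapters_itot']`, and a non-empty list is always truthy, so every parameter ends up frozen. A sketch of the presumably intended filter that keeps only the LoRA adapters trainable (the adapter names are taken from the diff, everything else is illustrative):

# Freeze everything except the LoRA adapter parameters (sketch, not the repository's code).
lora_modules = ('lora_adapters_ttoi', 'lora_adapters_itot')
for name, param in nnet.named_parameters():
    # a parameter stays trainable only if some component of its dotted path names a LoRA adapter
    param.requires_grad = any(part in lora_modules for part in name.split('.'))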
@@ -295,27 +305,28 @@ def train_step():
         bloss = LSimple_T2I(img=z, clip_img=clip_img, text=text, data_type=data_type, nnet=nnet, schedule=schedule, device=device, config=config,mask=mask)
-
+        bloss.requires_grad = True
         accelerator.backward(bloss)
         # Zero out the gradients for all token embeddings except the newly added
         # embeddings for the concept, as we only want to optimize the concept embeddings
-        if config.modifier_token is not None:
-            if accelerator.num_processes > 1:
-                grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
-            else:
-                grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
-            # Get the index for tokens that we want to zero the grads for
-            index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
-            for i in range(len(modifier_token_id[1:])):
-                index_grads_to_zero = index_grads_to_zero & (
-                    torch.arange(len(tokenizer)) != modifier_token_id[i]
-                )
-            grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
-                index_grads_to_zero, :
-            ].fill_(0)
+        # if config.modifier_token is not None:
+        #     if accelerator.num_processes > 1:
+        #         grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
+        #     else:
+        #         grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
+
+        #     # Get the index for tokens that we want to zero the grads for
+        #     index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
+        #     for i in range(len(modifier_token_id[1:])):
+        #         index_grads_to_zero = index_grads_to_zero & (
+        #             torch.arange(len(tokenizer)) != modifier_token_id[i]
+        #         )
+        #     grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
+        #         index_grads_to_zero, :
+        #     ].fill_(0)
+
-## The nnet here was replaced with nnet; does that still count as fine-tuning?
         if accelerator.sync_gradients:
             params_to_clip = (
                 itertools.chain(text_encoder.parameters(), nnet.parameters())
@@ -341,7 +352,7 @@ def train_step():
         metrics['bloss'] = accelerator.gather(bloss.detach().mean()).mean().item()
         # metrics['loss_img'] = accelerator.gather(loss_img.detach().mean()).mean().item()
         # metrics['loss_clip_img'] = accelerator.gather(loss_clip_img.detach().mean()).mean().item()
-        metrics['scale'] = accelerator.scaler.get_scale()
+        # metrics['scale'] = accelerator.scaler.get_scale()
         metrics['lr'] = train_state.optimizer.param_groups[0]['lr']
         return metrics
@@ -624,4 +635,14 @@ def main():

 if __name__ == "__main__":
     main()
-    
\ No newline at end of file
+
+"""
+accelerate launch train.py \
+  --instance_data_dir="/home/schengwei/competition/train_data/girl2" \
+  --outdir="/home/schengwei/competition/model_output/girl2"\
+  --class_data_dir "/home/schengwei/competition/real_reg/samples_person" \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --class_prompt="girl" --num_class_images=200 \
+  --instance_prompt="photo of a girl" \
+  --modifier_token ""
+"""
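The --with_prior_preservation and --prior_loss_weight flags in the command above correspond to the loss change in libs/schedule.py earlier in this diff: the batch is assumed to stack instance (concept) samples with class-prior samples, the model output and target are split in half, and the two mean MSE terms are combined. A minimal sketch of that combination under those assumptions, not the repository's exact code:

import torch
import torch.nn.functional as F

def prior_preservation_loss(img_out, target, prior_loss_weight=1.0):
    # first half of the batch: instance samples; second half: class-prior samples
    img_out, img_out_prior = torch.chunk(img_out, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    aloss = F.mse_loss(img_out.float(), target.float(), reduction="mean")                    # instance loss
    prior_loss = F.mse_loss(img_out_prior.float(), target_prior.float(), reduction="mean")   # prior loss
    return aloss + prior_loss_weight * prior_loss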