import collections
import collections.abc
import os
import pickle
import zipfile
from concurrent.futures import ThreadPoolExecutor

# Compatibility shim: the `collections.Iterable` alias was removed in
# Python 3.10. It is restored here before the third-party imports below,
# presumably because one of them still references the old name.
collections.Iterable = collections.abc.Iterable

import mindspore as ms
from mindnlp.transformers import AlignModel, AlignProcessor
from mindspore import Tensor, nn, ops, Parameter
from PIL import Image
from pycocotools.coco import COCO
from tqdm import tqdm
import numpy as np

# Central configuration for the fine-tuning run; every function in this
# script reads its settings from here.
HYPERPARAMS = {
    # --- model -----------------------------------------------------------
    # Local checkpoint directory passed to `from_pretrained`.
    # NOTE(review): machine-specific absolute path; adjust per environment.
    "model_name": "E:/Code/align_ft_torch/cache/model/kakaobrain/align-base",

    # --- optimisation ----------------------------------------------------
    "epochs": 10,
    "batch_size": 4,
    "learning_rate": 1e-4,

    # --- data ------------------------------------------------------------
    "train_samples": 200,   # number of COCO image ids to preprocess
    "max_length": 128,      # padded token length requested from the processor
    "num_workers": 8,       # threads used while preprocessing
    "data_dir": "MSCOCO",
    "data_type": "val2017",
    "train_cache_file": "mscoco_preprocessed_train_200.pkl",

    # --- outputs ---------------------------------------------------------
    "save_dir": "cache/model",
    # `{epoch}` is filled in with str.format when a checkpoint is written.
    "model_save_path": "cache/model/finetuned_align_model_epoch_{epoch}.ckpt",
    "processor_save_path": "cache/model/finetuned_align_processor",
}

# Execute in PyNative mode on an Ascend device and reset any auto-parallel
# configuration held by the MindSpore context.
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")
ms.context.reset_auto_parallel_context()

# Load the processor and the model from the local checkpoint directory only
# (`local_files_only=True`).
processor = AlignProcessor.from_pretrained(
    HYPERPARAMS["model_name"],
    local_files_only=True,
)
model = AlignModel.from_pretrained(
    HYPERPARAMS["model_name"],
    local_files_only=True,
)
model.set_train(True)

# Report the loaded configuration and how many parameters are trainable.
print("Model config:", model.config)
params = model.trainable_params()
print("Number of trainable params:", len(params))


def setup_coco():
    """Ensure the COCO directory layout exists and captions are extracted.

    Creates ``<data_dir>``, ``<data_dir>/annotations`` and
    ``<data_dir>/<data_type>`` (both names come from ``HYPERPARAMS``). If the
    captions JSON for ``data_type`` is missing, the contents of
    ``<data_dir>/annotations_trainval2017.zip`` are extracted into
    ``<data_dir>``.

    Returns:
        tuple: ``(data_dir, data_type)`` exactly as configured.

    Raises:
        FileNotFoundError: the captions file is missing and the archive has
            not been downloaded.
        zipfile.BadZipFile: the archive exists but cannot be read.
    """
    data_dir = HYPERPARAMS["data_dir"]
    data_type = HYPERPARAMS["data_type"]

    for directory in (data_dir, f"{data_dir}/annotations", f"{data_dir}/{data_type}"):
        os.makedirs(directory, exist_ok=True)

    ann_file = f"{data_dir}/annotations/captions_{data_type}.json"
    if os.path.exists(ann_file):
        return data_dir, data_type

    ann_zip = f"{data_dir}/annotations_trainval2017.zip"
    if not os.path.exists(ann_zip):
        raise FileNotFoundError(f"{ann_zip} not found. Please download it manually.")

    print("Extracting annotations...")
    # Extract in-process instead of shelling out to `unzip`: no dependency on
    # an external binary, no unquoted paths in a shell string, and a corrupt
    # archive raises instead of being silently ignored. Existing files are
    # overwritten, matching the previous `unzip -o` behaviour.
    with zipfile.ZipFile(ann_zip) as archive:
        archive.extractall(data_dir)
    return data_dir, data_type


# Prepare the dataset directories, then build the COCO caption index that
# the rest of the script queries.
dataDir, dataType = setup_coco()
annFile = f"{dataDir}/annotations/captions_{dataType}.json"
coco = COCO(annFile)


def get_image_and_caption(coco, img_id, cache_dir=f"{HYPERPARAMS['data_dir']}/{HYPERPARAMS['data_type']}"):
    """Load one COCO image as RGB together with its first caption.

    Args:
        coco: COCO annotation index exposing ``getAnnIds``, ``loadAnns`` and
            ``loadImgs``.
        img_id: Id of the image to fetch.
        cache_dir: Directory holding the image files. The default is built
            from ``HYPERPARAMS`` once, when the function is defined.

    Returns:
        tuple: ``(image, caption)`` where ``image`` is a fully loaded RGB
        PIL image and ``caption`` is the ``'caption'`` field of the first
        annotation returned for ``img_id``.

    Raises:
        IndexError: no annotation exists for ``img_id``.
        FileNotFoundError: the image file is not present in ``cache_dir``.
    """
    annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
    if not annotations:
        # Same exception type as the bare `anns[0]` lookup used to raise,
        # but with a message that identifies the offending image.
        raise IndexError(f"No caption annotations found for image id {img_id}")
    caption = annotations[0]["caption"]

    img_info = coco.loadImgs(img_id)[0]
    img_path = f"{cache_dir}/{img_info['file_name']}"

    # `Image.open` is lazy and keeps the file open. `convert` always returns
    # a new, fully loaded image (a copy when the mode is already RGB), so the
    # handle can be closed deterministically here.
    with Image.open(img_path) as source:
        image = source.convert("RGB")
    return image, caption


def process_sample(img_id, coco):
    """Encode one image/caption pair with the module-level ``processor``.

    Args:
        img_id: COCO image id to encode.
        coco: COCO annotation index used to look up the image and caption.

    Returns:
        tuple: ``(input_ids, attention_mask, pixel_values)``, each taken at
        index 0 of the corresponding processor output.
    """
    image, caption = get_image_and_caption(coco, img_id)
    encoded = processor(
        text=caption,
        images=image,
        return_tensors="ms",
        padding="max_length",
        max_length=HYPERPARAMS["max_length"],
    )
    fields = ("input_ids", "attention_mask", "pixel_values")
    # The processor was given a single pair; keep element 0 of each output.
    return tuple(encoded[field][0] for field in fields)


def preprocess_and_save(coco, num_samples, cache_file):
    """Return the preprocessed dataset, using ``cache_file`` when present.

    If ``cache_file`` exists it is unpickled and returned as-is (it is not
    checked against ``num_samples``). Otherwise the first ``num_samples``
    image ids are encoded with ``process_sample`` on a thread pool, the
    result is pickled to ``cache_file`` and returned.

    Args:
        coco: COCO annotation index.
        num_samples: Maximum number of image ids to preprocess.
        cache_file: Path of the pickle cache to read or create.

    Returns:
        list: One ``process_sample`` result per image, in image-id order.
    """
    if os.path.exists(cache_file):
        print(f"Loading preprocessed data from {cache_file}")
        # NOTE: unpickling executes code embedded in the file. Only load
        # cache files produced by this script on a trusted machine.
        with open(cache_file, "rb") as f:
            dataset = pickle.load(f)
        print(f"Loaded dataset size: {len(dataset)} samples")
        return dataset

    img_ids = coco.getImgIds()[:num_samples]
    # The annotation file may hold fewer images than requested; size the
    # progress bar from what will actually be processed.
    sample_count = len(img_ids)

    with ThreadPoolExecutor(max_workers=HYPERPARAMS["num_workers"]) as executor:
        encoded = executor.map(lambda img_id: process_sample(img_id, coco), img_ids)
        dataset = list(tqdm(encoded, total=sample_count,
                            desc=f"Preprocessing dataset ({sample_count} samples)"))

    # Write to a temporary file and rename it into place, so an interrupted
    # run never leaves a truncated cache that would break the next load.
    tmp_file = f"{cache_file}.tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(dataset, f)
    os.replace(tmp_file, cache_file)
    return dataset


def create_train_dataloader(coco, batch_size=HYPERPARAMS["batch_size"]):
    """Build a batched MindSpore dataset over the preprocessed samples.

    Args:
        coco: COCO annotation index, forwarded to ``preprocess_and_save``.
        batch_size: Number of samples per batch.

    Returns:
        The ``GeneratorDataset`` with columns ``input_ids``,
        ``attention_mask`` and ``pixel_values``, batched by ``batch_size``.
    """
    samples = preprocess_and_save(
        coco,
        HYPERPARAMS["train_samples"],
        HYPERPARAMS["train_cache_file"],
    )
    columns = ["input_ids", "attention_mask", "pixel_values"]
    dataset = ms.dataset.GeneratorDataset(samples, column_names=columns)
    return dataset.batch(batch_size)


class TrainingNet(nn.Cell):
    """Wrap an ALIGN model and return a symmetric contrastive loss.

    The vision encoder output is average-pooled to one vector per image; the
    text encoder output at sequence position 0 is passed through a freshly
    initialised ``text_projection`` layer. The similarity matrix of the two
    is scaled by ``exp(logit_scale)`` and scored with cross-entropy in both
    directions, treating the i-th image and i-th text of a batch as the
    matching pair.

    Args:
        model: Model exposing ``vision_model.embeddings``,
            ``vision_model.encoder`` and ``text_model``.
        text_dim: Input width of ``text_projection`` (default 768).
        embed_dim: Output width of ``text_projection`` (default 640).
            NOTE(review): presumably must equal the channel count of the
            pooled vision features for the matmul to be valid; confirm.
    """

    def __init__(self, model, text_dim=768, embed_dim=640):
        super().__init__()
        self.model = model
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.text_projection = nn.Dense(text_dim, embed_dim)
        # Learnable scale stored in log space; initialised to log(1 / 0.07).
        self.logit_scale = Parameter(Tensor(np.log(1 / 0.07), dtype=ms.float32), requires_grad=True)
        # Built once here instead of on every forward pass. The loss cell
        # holds no parameters, so saved checkpoints are unaffected.
        self.loss_fn = nn.CrossEntropyLoss()
        # Embeddings of the most recent forward pass, kept so callers can
        # inspect them after a training step.
        self.image_embeds = None
        self.text_embeds = None

    def construct(self, input_ids, attention_mask, pixel_values):
        """Compute the mean of the image-to-text and text-to-image losses.

        Args:
            input_ids: Token ids forwarded to the text model.
            attention_mask: Attention mask forwarded to the text model.
            pixel_values: Image batch forwarded to the vision embeddings.

        Returns:
            The scalar loss ``(loss_i2t + loss_t2i) / 2``.
        """
        # Image branch: embeddings -> encoder -> global average pool, then
        # keep only the first two axes of the pooled output.
        embedding_output = self.model.vision_model.embeddings(pixel_values)
        encoder_outputs = self.model.vision_model.encoder(embedding_output)
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.global_pool(last_hidden_state)
        self.image_embeds = pooled_output.reshape(pooled_output.shape[:2])

        # Text branch: take sequence position 0 of the first output and
        # project it.
        text_outputs = self.model.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = text_outputs[0][:, 0, :]
        self.text_embeds = self.text_projection(text_embeds)

        # NOTE(review): the embeddings are not L2-normalised before the
        # similarity product; confirm this is intended.
        logits = ops.matmul(self.image_embeds, self.text_embeds.T) * ops.exp(self.logit_scale)
        labels = ops.arange(logits.shape[0], dtype=ms.int32)
        loss_i2t = self.loss_fn(logits, labels)
        loss_t2i = self.loss_fn(logits.T, labels)
        return (loss_i2t + loss_t2i) / 2


def convert_to_parameter(params):
    """Return a list in which every entry is a MindSpore ``Parameter``.

    Entries that already are ``Parameter`` instances are passed through
    unchanged. Anything else is wrapped in a new trainable ``Parameter``
    built from its ``.data``, named after its ``name`` attribute when it has
    one and ``param_<index>`` otherwise.

    Args:
        params: Iterable of parameters or parameter-like objects.

    Returns:
        list: The converted entries, in input order.
    """

    def _as_parameter(index, candidate):
        if isinstance(candidate, Parameter):
            return candidate
        label = getattr(candidate, "name", f"param_{index}")
        return Parameter(candidate.data, name=label, requires_grad=True)

    return [_as_parameter(index, candidate) for index, candidate in enumerate(params)]


def finetune_model(coco, model, processor,
                   epochs=HYPERPARAMS["epochs"],
                   batch_size=HYPERPARAMS["batch_size"],
                   learning_rate=HYPERPARAMS["learning_rate"]):
    """Fine-tune ``model`` with the contrastive loss of ``TrainingNet``.

    Builds the training dataloader, wraps ``model`` in ``TrainingNet``,
    optimises the model's trainable parameters together with the wrapper's
    projection layer and logit scale using Adam, and writes one checkpoint
    per epoch. The processor is saved once after the last epoch.

    Args:
        coco: COCO annotation index used to build the dataloader.
        model: Model to fine-tune.
        processor: Processor saved to ``HYPERPARAMS["processor_save_path"]``.
        epochs: Number of passes over the dataloader.
        batch_size: Samples per batch.
        learning_rate: Adam learning rate.

    Returns:
        The ``model`` object that was passed in.

    Raises:
        RuntimeError: the dataloader yields no batches.
    """
    train_dataloader = create_train_dataloader(coco, batch_size)
    print(f"Train dataloader created with batch_size={batch_size}, samples={HYPERPARAMS['train_samples']}")

    params = model.trainable_params()
    if not params:
        print("No trainable params found, enabling all parameters.")
        for _, param in model.parameters_and_names():
            param.requires_grad = True
        params = model.trainable_params()

    params = convert_to_parameter(params)
    print(f"Optimizer initialized with {len(params)} parameters")
    net = TrainingNet(model)
    # Parameters that belong to the wrapper, not to `model` itself.
    head_params = [net.text_projection.weight, net.text_projection.bias, net.logit_scale]
    optimizer = nn.Adam(params + head_params, learning_rate=learning_rate)
    train_net = nn.TrainOneStepCell(net, optimizer)

    # The output directory does not change between epochs; create it once.
    save_dir = HYPERPARAMS["save_dir"]
    os.makedirs(save_dir, exist_ok=True)

    # Snapshot the projection weights BEFORE any optimizer step, so the
    # per-epoch "Params updated" report compares against the initial values
    # (a snapshot taken after epoch 1 always reports False for epoch 1).
    param_before = net.text_projection.weight.asnumpy().copy()

    for epoch in range(epochs):
        iterator = train_dataloader.create_dict_iterator()
        total_train_loss = 0.0
        steps = 0
        for batch in tqdm(iterator, desc=f"Epoch {epoch + 1}/{epochs} (Train)"):
            loss = train_net(batch["input_ids"], batch["attention_mask"], batch["pixel_values"])
            loss_value = float(loss.asnumpy())
            total_train_loss += loss_value
            steps += 1
            if steps == 1:
                # One-off diagnostic per epoch: first loss and a corner of
                # the logits rebuilt from the embeddings cached on `net`.
                print(f"Epoch {epoch + 1}, Step 1 - Train Loss: {loss_value:.4f}")
                logits = ops.matmul(net.image_embeds, net.text_embeds.T) * ops.exp(net.logit_scale)
                print(f"Logits sample: {logits[:2, :2]}")

        if steps == 0:
            # Previously this surfaced as a ZeroDivisionError below.
            raise RuntimeError("Training dataloader produced no batches; check the dataset and batch size.")

        avg_train_loss = total_train_loss / steps
        print(f"Epoch {epoch + 1}/{epochs}, Average Train Loss: {avg_train_loss:.4f}")

        param_after = net.text_projection.weight.asnumpy()
        print("Params updated:", not np.array_equal(param_before, param_after))

        ms.save_checkpoint(net, HYPERPARAMS["model_save_path"].format(epoch=epoch + 1))

    processor.save_pretrained(HYPERPARAMS["processor_save_path"])
    return model


# Script entry point: run fine-tuning with the module-level COCO index,
# model and processor, and keep the returned model at module scope.
print("Starting model finetuning...")
finetuned_model = finetune_model(coco, model, processor)