
Commit b228c1d

[Open-source internship] Align model fine-tuning IAUOS5 (#1997)
1 parent 797fade commit b228c1d

2 files changed: +363 −0 lines changed

llm/finetune/Align/eval_model.py

Lines changed: 160 additions & 0 deletions
import collections
import collections.abc

# Compatibility shim: pycocotools still references collections.Iterable,
# which Python 3.10 removed.
collections.Iterable = collections.abc.Iterable

import mindspore as ms
from mindnlp.transformers import AlignModel, AlignProcessor
from mindspore import Tensor, nn, ops, Parameter
from PIL import Image  # get_image_and_caption below needs this
from pycocotools.coco import COCO
import os
from tqdm import tqdm
import pickle
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import gc
HYPERPARAMS = {
    "model_name": "E:/Code/align_ft_torch/cache/model/kakaobrain/align-base",
    "batch_size": 4,
    "val_samples": 50,
    "max_length": 128,
    "num_workers": 8,
    "data_dir": "MSCOCO",
    "data_type": "val2017",
    "val_cache_file": "mscoco_preprocessed_val_50.pkl",
    "save_dir": "cache/model",
    "model_save_path": "cache/model/finetuned_align_model_epoch_{epoch}.ckpt",
    "processor_save_path": "cache/model/finetuned_align_processor"
}

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")
ms.context.reset_auto_parallel_context()
def setup_coco():
    dataDir = HYPERPARAMS["data_dir"]
    dataType = HYPERPARAMS["data_type"]
    os.makedirs(dataDir, exist_ok=True)
    os.makedirs(f"{dataDir}/annotations", exist_ok=True)
    os.makedirs(f"{dataDir}/{dataType}", exist_ok=True)
    ann_file = f"{dataDir}/annotations/captions_{dataType}.json"
    if not os.path.exists(ann_file):
        ann_zip = f"{dataDir}/annotations_trainval2017.zip"
        if not os.path.exists(ann_zip):
            raise FileNotFoundError(f"{ann_zip} not found. Please download it manually.")
        print("Extracting annotations...")
        os.system(f"unzip -o {ann_zip} -d {dataDir}")
    return dataDir, dataType


dataDir, dataType = setup_coco()
annFile = f'{dataDir}/annotations/captions_{dataType}.json'
coco = COCO(annFile)
def get_image_and_caption(coco, img_id, cache_dir=f"{HYPERPARAMS['data_dir']}/{HYPERPARAMS['data_type']}"):
    ann_ids = coco.getAnnIds(imgIds=img_id)
    anns = coco.loadAnns(ann_ids)
    caption = anns[0]['caption']  # use the first of the ~5 COCO captions per image
    img_info = coco.loadImgs(img_id)[0]
    img_path = f"{cache_dir}/{img_info['file_name']}"
    image = Image.open(img_path)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image, caption


def process_sample(img_id, coco):
    image, caption = get_image_and_caption(coco, img_id)
    # Reloads the saved processor on every call; tolerable for 50 samples.
    processor = AlignProcessor.from_pretrained(HYPERPARAMS["processor_save_path"])
    inputs = processor(
        text=caption,
        images=image,
        return_tensors="ms",
        padding="max_length",
        max_length=HYPERPARAMS["max_length"]
    )
    return (inputs["input_ids"][0], inputs["attention_mask"][0], inputs["pixel_values"][0])
def preprocess_and_save(coco, num_samples, cache_file):
    if os.path.exists(cache_file):
        print(f"Loading preprocessed data from {cache_file}")
        with open(cache_file, "rb") as f:
            dataset = pickle.load(f)
        print(f"Loaded dataset size: {len(dataset)} samples")
        return dataset
    img_ids = coco.getImgIds()[:num_samples]
    with ThreadPoolExecutor(max_workers=HYPERPARAMS["num_workers"]) as executor:
        dataset = list(tqdm(executor.map(lambda x: process_sample(x, coco), img_ids),
                            total=num_samples, desc=f"Preprocessing dataset ({num_samples} samples)"))
    with open(cache_file, "wb") as f:
        pickle.dump(dataset, f)
    return dataset


def create_val_dataloader(coco, batch_size=HYPERPARAMS["batch_size"]):
    val_dataset = preprocess_and_save(coco, HYPERPARAMS["val_samples"], HYPERPARAMS["val_cache_file"])
    val_dataloader = ms.dataset.GeneratorDataset(
        val_dataset,
        column_names=["input_ids", "attention_mask", "pixel_values"]
    ).batch(batch_size)
    return val_dataloader
class TrainingNet(nn.Cell):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Custom head: pool the vision encoder output and project the BERT
        # [CLS] embedding (768-dim) down to the 640-dim image features.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.text_projection = nn.Dense(768, 640)
        # Learnable temperature, initialized to log(1/0.07) as in CLIP.
        self.logit_scale = Parameter(Tensor(np.log(1 / 0.07), dtype=ms.float32), requires_grad=True)
        self.image_embeds = None
        self.text_embeds = None

    def construct(self, input_ids, attention_mask, pixel_values):
        embedding_output = self.model.vision_model.embeddings(pixel_values)
        encoder_outputs = self.model.vision_model.encoder(embedding_output)
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.global_pool(last_hidden_state)
        self.image_embeds = pooled_output.reshape(pooled_output.shape[:2])
        text_outputs = self.model.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = text_outputs[0][:, 0, :]
        self.text_embeds = self.text_projection(text_embeds)
        # Symmetric contrastive loss over the in-batch similarity matrix;
        # matching image/caption pairs sit on the diagonal.
        logits = ops.matmul(self.image_embeds, self.text_embeds.T) * ops.exp(self.logit_scale)
        labels = ops.arange(len(logits), dtype=ms.int32)
        loss_i2t = nn.CrossEntropyLoss()(logits, labels)
        loss_t2i = nn.CrossEntropyLoss()(logits.T, labels)
        return (loss_i2t + loss_t2i) / 2
def evaluate_model(coco, epoch_to_eval):
    processor = AlignProcessor.from_pretrained(HYPERPARAMS["processor_save_path"])
    model = AlignModel.from_pretrained(HYPERPARAMS["model_name"], local_files_only=True)
    net = TrainingNet(model)  # wrap AlignModel in TrainingNet
    param_dict = ms.load_checkpoint(HYPERPARAMS["model_save_path"].format(epoch=epoch_to_eval))
    ms.load_param_into_net(net, param_dict)  # load the fine-tuned weights into TrainingNet
    net.set_train(False)

    val_dataloader = create_val_dataloader(coco)
    print(f"Val dataloader created with batch_size={HYPERPARAMS['batch_size']}, samples={HYPERPARAMS['val_samples']}")

    total_val_loss = 0
    val_steps = 0
    for batch in tqdm(val_dataloader.create_dict_iterator(), desc=f"Evaluating Epoch {epoch_to_eval}"):
        loss = net(batch["input_ids"], batch["attention_mask"], batch["pixel_values"])
        total_val_loss += loss.asnumpy()
        val_steps += 1
    avg_val_loss = total_val_loss / val_steps
    print(f"Epoch {epoch_to_eval}, Eval Loss: {avg_val_loss:.4f}")

    gc.collect()
    return avg_val_loss


if __name__ == "__main__":
    print("Starting model evaluation...")
    for epoch in range(1, 11):
        evaluate_model(coco, epoch)
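
For reference, TrainingNet.construct above computes the standard CLIP-style symmetric contrastive loss over the in-batch similarity matrix. Below is a minimal NumPy sketch of the same computation (illustrative only, with hypothetical helper names; note the original CLIP/ALIGN recipe L2-normalizes the embeddings before the dot product, which this commit skips):

import numpy as np

def symmetric_contrastive_loss(image_embeds, text_embeds, logit_scale):
    # Pairwise similarity: row i compares image i against every caption.
    logits = image_embeds @ text_embeds.T * np.exp(logit_scale)
    labels = np.arange(len(logits))  # matching pairs sit on the diagonal

    def cross_entropy(l, y):
        # Numerically stable softmax cross-entropy, averaged over the batch.
        l = l - l.max(axis=1, keepdims=True)
        log_probs = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -log_probs[np.arange(len(y)), y].mean()

    loss_i2t = cross_entropy(logits, labels)    # image -> text direction
    loss_t2i = cross_entropy(logits.T, labels)  # text -> image direction
    return (loss_i2t + loss_t2i) / 2

# Tiny usage example: random 640-dim embeddings for a batch of 4.
rng = np.random.default_rng(0)
print(symmetric_contrastive_loss(rng.normal(size=(4, 640)),
                                 rng.normal(size=(4, 640)),
                                 np.log(1 / 0.07)))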

llm/finetune/Align/finetune.py

Lines changed: 203 additions & 0 deletions
import collections
import collections.abc

# Compatibility shim: pycocotools still references collections.Iterable,
# which Python 3.10 removed.
collections.Iterable = collections.abc.Iterable

import mindspore as ms
from mindnlp.transformers import AlignModel, AlignProcessor
from mindspore import Tensor, nn, ops, Parameter
from PIL import Image
from pycocotools.coco import COCO
import os
from tqdm import tqdm
import pickle
from concurrent.futures import ThreadPoolExecutor
import numpy as np
HYPERPARAMS = {
    "model_name": "E:/Code/align_ft_torch/cache/model/kakaobrain/align-base",
    "epochs": 10,
    "batch_size": 4,
    "learning_rate": 1e-4,
    "train_samples": 200,
    "max_length": 128,
    "num_workers": 8,
    "data_dir": "MSCOCO",
    # Both scripts draw from val2017: the first 200 image ids here for
    # training, the first 50 in eval_model.py for evaluation.
    "data_type": "val2017",
    "train_cache_file": "mscoco_preprocessed_train_200.pkl",
    "save_dir": "cache/model",
    "model_save_path": "cache/model/finetuned_align_model_epoch_{epoch}.ckpt",
    "processor_save_path": "cache/model/finetuned_align_processor"
}

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")
ms.context.reset_auto_parallel_context()

processor = AlignProcessor.from_pretrained(HYPERPARAMS["model_name"], local_files_only=True)
model = AlignModel.from_pretrained(HYPERPARAMS["model_name"], local_files_only=True)
model.set_train(True)

print("Model config:", model.config)
params = model.trainable_params()
print("Number of trainable params:", len(params))
def setup_coco():
    dataDir = HYPERPARAMS["data_dir"]
    dataType = HYPERPARAMS["data_type"]
    os.makedirs(dataDir, exist_ok=True)
    os.makedirs(f"{dataDir}/annotations", exist_ok=True)
    os.makedirs(f"{dataDir}/{dataType}", exist_ok=True)
    ann_file = f"{dataDir}/annotations/captions_{dataType}.json"
    if not os.path.exists(ann_file):
        ann_zip = f"{dataDir}/annotations_trainval2017.zip"
        if not os.path.exists(ann_zip):
            raise FileNotFoundError(f"{ann_zip} not found. Please download it manually.")
        print("Extracting annotations...")
        os.system(f"unzip -o {ann_zip} -d {dataDir}")
    return dataDir, dataType


dataDir, dataType = setup_coco()
annFile = f'{dataDir}/annotations/captions_{dataType}.json'
coco = COCO(annFile)
def get_image_and_caption(coco, img_id, cache_dir=f"{HYPERPARAMS['data_dir']}/{HYPERPARAMS['data_type']}"):
    ann_ids = coco.getAnnIds(imgIds=img_id)
    anns = coco.loadAnns(ann_ids)
    caption = anns[0]['caption']
    img_info = coco.loadImgs(img_id)[0]
    img_path = f"{cache_dir}/{img_info['file_name']}"
    image = Image.open(img_path)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image, caption


def process_sample(img_id, coco):
    image, caption = get_image_and_caption(coco, img_id)
    inputs = processor(
        text=caption,
        images=image,
        return_tensors="ms",
        padding="max_length",
        max_length=HYPERPARAMS["max_length"]
    )
    return (inputs["input_ids"][0], inputs["attention_mask"][0], inputs["pixel_values"][0])
def preprocess_and_save(coco, num_samples, cache_file):
    if os.path.exists(cache_file):
        print(f"Loading preprocessed data from {cache_file}")
        with open(cache_file, "rb") as f:
            dataset = pickle.load(f)
        print(f"Loaded dataset size: {len(dataset)} samples")
        return dataset
    img_ids = coco.getImgIds()[:num_samples]
    with ThreadPoolExecutor(max_workers=HYPERPARAMS["num_workers"]) as executor:
        dataset = list(tqdm(executor.map(lambda x: process_sample(x, coco), img_ids),
                            total=num_samples, desc=f"Preprocessing dataset ({num_samples} samples)"))
    with open(cache_file, "wb") as f:
        pickle.dump(dataset, f)
    return dataset


def create_train_dataloader(coco, batch_size=HYPERPARAMS["batch_size"]):
    train_dataset = preprocess_and_save(coco, HYPERPARAMS["train_samples"], HYPERPARAMS["train_cache_file"])
    train_dataloader = ms.dataset.GeneratorDataset(
        train_dataset,
        column_names=["input_ids", "attention_mask", "pixel_values"]
    ).batch(batch_size)
    return train_dataloader
class TrainingNet(nn.Cell):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.text_projection = nn.Dense(768, 640)
        self.logit_scale = Parameter(Tensor(np.log(1 / 0.07), dtype=ms.float32), requires_grad=True)
        self.image_embeds = None
        self.text_embeds = None

    def construct(self, input_ids, attention_mask, pixel_values):
        embedding_output = self.model.vision_model.embeddings(pixel_values)
        encoder_outputs = self.model.vision_model.encoder(embedding_output)
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.global_pool(last_hidden_state)
        self.image_embeds = pooled_output.reshape(pooled_output.shape[:2])
        text_outputs = self.model.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = text_outputs[0][:, 0, :]
        self.text_embeds = self.text_projection(text_embeds)
        logits = ops.matmul(self.image_embeds, self.text_embeds.T) * ops.exp(self.logit_scale)
        labels = ops.arange(len(logits), dtype=ms.int32)
        loss_i2t = nn.CrossEntropyLoss()(logits, labels)
        loss_t2i = nn.CrossEntropyLoss()(logits.T, labels)
        return (loss_i2t + loss_t2i) / 2
def convert_to_parameter(params):
    # MindSpore optimizers expect Parameter objects; wrap any raw tensors.
    converted = []
    for i, param in enumerate(params):
        if not isinstance(param, Parameter):
            name = getattr(param, 'name', f"param_{i}")
            converted.append(Parameter(param.data, name=name, requires_grad=True))
        else:
            converted.append(param)
    return converted
def finetune_model(coco, model, processor,
                   epochs=HYPERPARAMS["epochs"],
                   batch_size=HYPERPARAMS["batch_size"],
                   learning_rate=HYPERPARAMS["learning_rate"]):
    train_dataloader = create_train_dataloader(coco, batch_size)
    print(f"Train dataloader created with batch_size={batch_size}, samples={HYPERPARAMS['train_samples']}")

    params = model.trainable_params()
    if not params:
        print("No trainable params found, enabling all parameters.")
        for param in model.parameters_and_names():
            param[1].requires_grad = True
        params = model.trainable_params()

    params = convert_to_parameter(params)
    print(f"Optimizer initialized with {len(params)} parameters")
    net = TrainingNet(model)
    # Optimize the backbone together with the extra head added by TrainingNet.
    optimizer = nn.Adam(params + [net.text_projection.weight, net.text_projection.bias, net.logit_scale],
                        learning_rate=learning_rate)
    train_net = nn.TrainOneStepCell(net, optimizer)
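    # nn.TrainOneStepCell bundles the forward pass, gradient computation,
    # and the Adam update: each call below executes one full training step
    # and returns the loss for that batch.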
    # Snapshot the projection weights once, before training, so the
    # per-epoch check below is meaningful from the first epoch on.
    param_before = net.text_projection.weight.asnumpy().copy()
    for epoch in range(epochs):
        iterator = train_dataloader.create_dict_iterator()
        total_train_loss = 0
        steps = 0
        for batch in tqdm(iterator, desc=f"Epoch {epoch + 1}/{epochs} (Train)"):
            loss = train_net(batch["input_ids"], batch["attention_mask"], batch["pixel_values"])
            total_train_loss += loss.asnumpy()
            steps += 1
            if steps == 1:
                print(f"Epoch {epoch + 1}, Step 1 - Train Loss: {loss.asnumpy():.4f}")
                logits = ops.matmul(net.image_embeds, net.text_embeds.T) * ops.exp(net.logit_scale)
                print(f"Logits sample: {logits[:2, :2]}")
        avg_train_loss = total_train_loss / steps
        print(f"Epoch {epoch + 1}/{epochs}, Average Train Loss: {avg_train_loss:.4f}")

        param_after = net.text_projection.weight.asnumpy()
        print("Params updated:", not np.array_equal(param_before, param_after))

        save_dir = HYPERPARAMS["save_dir"]
        os.makedirs(save_dir, exist_ok=True)
        ms.save_checkpoint(net, HYPERPARAMS["model_save_path"].format(epoch=epoch + 1))

    processor.save_pretrained(HYPERPARAMS["processor_save_path"])
    return model


print("Starting model finetuning...")
finetuned_model = finetune_model(coco, model, processor)
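
Note that neither script downloads the COCO images themselves: setup_coco only extracts the caption annotations from a manually downloaded zip, and get_image_and_caption reads image files straight from MSCOCO/val2017. A minimal sanity check for the layout both scripts assume (a sketch; the annotation zip and the val2017 images must be fetched separately, e.g. from cocodataset.org):

import os

# Layout assumed by setup_coco() and get_image_and_caption() above.
for path in ("MSCOCO/annotations/captions_val2017.json",  # inside annotations_trainval2017.zip
             "MSCOCO/val2017"):                           # directory of val2017 image files
    print(path, "OK" if os.path.exists(path) else "MISSING")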
