Skip to content

Commit 94981cd

Browse files
authored
【开源实习】bart模型微调 (#2026)
1 parent a5ab505 commit 94981cd

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

llm/finetune/bart/bart_finetune.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from mindspore import nn, ops, Tensor
2+
from mindspore.dataset import GeneratorDataset
3+
from mindnlp.transformers import BartForConditionalGeneration, BartTokenizer
4+
from mindnlp.engine import Trainer, TrainingArguments
5+
from datasets import load_dataset
6+
7+
import evaluate
8+
import mindspore as ms
9+
10+
11+
rouge_metric = evaluate.load("rouge")
12+
# Load dataset and tokenizer
13+
tokenizer = BartTokenizer.from_pretrained("./bart-base")
14+
15+
dataset = load_dataset("xsum", split="train")
16+
val_dataset = load_dataset("xsum", split="validation")
17+
18+
19+
def preprocess_function(examples):
20+
inputs = tokenizer(examples["document"], max_length=512,
21+
truncation=True, padding="max_length")
22+
targets = tokenizer(
23+
examples["summary"], max_length=128, truncation=True, padding="max_length")
24+
inputs["labels"] = targets["input_ids"]
25+
return inputs
26+
27+
28+
tokenized_data = dataset.map(preprocess_function, batched=True, remove_columns=[
29+
"document", "summary", "id"], num_proc=24)
30+
tokenized_val_data = val_dataset.map(preprocess_function, batched=True, remove_columns=[
31+
"document", "summary", "id"], num_proc=24)
32+
33+
34+
# Load model
35+
model = BartForConditionalGeneration.from_pretrained("./bart-base")
36+
37+
38+
def create_mindspore_dataset(data, batch_size=8):
39+
data_list = list(data)
40+
41+
def generator():
42+
for item in data_list:
43+
yield (
44+
Tensor(item["input_ids"], dtype=ms.int32),
45+
Tensor(item["attention_mask"], dtype=ms.int32),
46+
Tensor(item["labels"], dtype=ms.int32)
47+
)
48+
49+
return GeneratorDataset(generator, column_names=["input_ids", "attention_mask", "labels"]).batch(batch_size)
50+
51+
52+
def compute_metrics(pred):
53+
54+
labels_ids = pred.label_ids
55+
pred_ids = pred.predictions[0]
56+
57+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
58+
labels_ids[labels_ids == -100] = tokenizer.pad_token_id
59+
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
60+
61+
rouge_output = rouge_metric.compute(
62+
predictions=pred_str,
63+
references=label_str,
64+
rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"],
65+
)
66+
67+
return {
68+
"R1": round(rouge_output["rouge1"], 4),
69+
"R2": round(rouge_output["rouge2"], 4),
70+
"RL": round(rouge_output["rougeL"], 4),
71+
"RLsum": round(rouge_output["rougeLsum"], 4),
72+
}
73+
74+
75+
def preprocess_logits_for_metrics(logits, labels):
76+
"""
77+
防止内存溢出
78+
"""
79+
pred_ids = ms.mint.argmax(logits[0], dim=-1)
80+
return pred_ids, labels
81+
82+
83+
train_dataset = create_mindspore_dataset(tokenized_data, batch_size=4)
84+
eval_dataset = create_mindspore_dataset(tokenized_val_data, batch_size=2)
85+
86+
training_args = TrainingArguments(
87+
output_dir="./results",
88+
evaluation_strategy="epoch",
89+
learning_rate=2e-5,
90+
per_device_train_batch_size=4,
91+
per_device_eval_batch_size=2,
92+
num_train_epochs=3,
93+
weight_decay=0.01,
94+
save_total_limit=2,
95+
)
96+
97+
trainer = Trainer(
98+
model=model,
99+
args=training_args,
100+
train_dataset=train_dataset,
101+
eval_dataset=eval_dataset,
102+
tokenizer=tokenizer,
103+
compute_metrics=compute_metrics,
104+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
105+
)
106+
107+
trainer.train()
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
## bart模型微调报告
2+
3+
### 任务
4+
- **任务编号**:#IAUOXU
5+
- **任务链接**[【开源实习】bart模型微调](https://gitee.com/mindspore/community/issues/IAUOXU)
6+
- **实现内容**:实现了bart模型在XSum数据集上的微调。
7+
- **模型**`facebook/bart-base`
8+
- **数据集**`EdinburghNLP/xsum`
9+
10+
---
11+
12+
### 结果对比
13+
14+
#### **Mindnlp+D910B**
15+
16+
| Epoch | Eval Loss | R1 (ROUGE-1) | R2 (ROUGE-2) | RL (ROUGE-L) | RLsum (ROUGE-Lsum) |
17+
|------:|----------:|-------------:|-------------:|-------------:|-------------------:|
18+
| 1 | 0.4504 | 0.5265 | 0.2512 | 0.5003 | 0.5004 |
19+
| 2 | 0.4481 | 0.5272 | 0.2538 | 0.5026 | 0.5025 |
20+
| 3 | 0.4440 | 0.5316 | 0.2580 | 0.5061 | 0.5062 |
21+
22+
---
23+
24+
#### **Pytorch+3090**
25+
26+
| Epoch | Eval Loss | R1 (ROUGE-1) | R2 (ROUGE-2) | RL (ROUGE-L) | RLsum (ROUGE-Lsum) |
27+
|------:|----------:|-------------:|-------------:|-------------:|-------------------:|
28+
| 1 | 0.4364 | 0.5226 | 0.2432 | 0.4965 | 0.4961 |
29+
| 2 | 0.4297 | 0.5309 | 0.2547 | 0.5066 | 0.5065 |
30+
| 3 | 0.4290 | 0.5318 | 0.2563 | 0.5065 | 0.5062 |
31+
32+
---

0 commit comments

Comments
 (0)