Merge pull request FlagAI-Open#227 from Anhforth/opt_prompt
Optimize prompt
ftgreat authored Feb 16, 2023
2 parents 509a774 + 4fa2c16 commit 0ee9841
Showing 18 changed files with 493 additions and 191 deletions.
9 changes: 5 additions & 4 deletions examples/AltDiffusion/README.md
@@ -66,9 +66,10 @@ prompt = "Anime portrait of natalie portman as an anime girl by stanley artgerm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


-loader = AutoLoader(task_name="text2img", #contrastive learning
+loader = AutoLoader(task_name="text2img",
                     model_name="AltDiffusion-m9",
-                    model_dir="./checkpoints")
+                    model_dir="./checkpoints",
+                    use_fp16=False) # Fp16 mode

model = loader.get_model()
model.eval()
@@ -97,9 +98,9 @@ More parameters of predict_generate_images for you to adjust for `predict_genera
| C | int | 图片的channel数; Number of channels of generated images |
| seed | int | 随机种子; Random seed number |

-注意:模型推理要求一张至少10G以上的GPU
+注意:模型推理要求一张至少14G以上的GPU, FP16模式下则至少11G

-Note that the model inference requires a GPU of at least 10G above.
+Note that the model inference requires a GPU with at least 14 GB of memory, or at least 11 GB in FP16 mode.
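
For reference, generation with the loaded model typically goes through FlagAI's `Predictor` helper. The sketch below is illustrative only: the import path and keyword arguments are assumptions based on the parameter table above, not something verified against this commit.

```python
from flagai.model.predictor.predictor import Predictor  # assumed import path

predictor = Predictor(model)
# Only `prompt` and `seed` are shown here; H, W, C and the other
# parameters from the table above can be passed the same way.
predictor.predict_generate_images(prompt, seed=34234)
```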


# 更多生成结果/More Results
59 changes: 59 additions & 0 deletions examples/glm_custom_pvp/README.md
@@ -0,0 +1,59 @@
# Custom prompt-verbalizer pair (PVP)

## 1. Define your own prompt-verbalizer patterns
We provide an API for users to define their own prompt-verbalizer patterns. Here is an example:
```python
class RtePVP(PVP):
    # The verbalizer converts the original labels to more meaningful ones
VERBALIZER = {"not_entailment": [" No"], "entailment": [" Yes"]}

@staticmethod
def available_patterns():
return [0, 1, 2]

@property
def spell_length(self):
return self.num_prompt_tokens + self.prefix_prompt

def get_parts(self, example: InputExample):
"""
        Construct patterns from the input texts and the mask; "None" marks places where continuous prompt tokens are inserted
"""
text_a = example.text_a
text_b = example.text_b.rstrip(string.punctuation)
if self.pattern_id == 0:
parts_a, parts_b = [None, '"',
self.shortenable(text_b), '" ?'], [
None, [self.mask], ',', None, ' "',
self.shortenable(text_a), '"'
]
elif self.pattern_id == 1:
parts_a, parts_b = [None, self.shortenable(text_b), '?'], [
None, [self.mask], ',', None,
self.shortenable(" " + text_a)
]
elif self.pattern_id == 2:
parts_a, parts_b = [
None,
self.shortenable(text_a), None, ' question:',
self.shortenable(" " + text_b), ' True or False?', None,
' answer:', [self.mask]
], []
else:
raise NotImplementedError(self.pattern_id)
parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b)
return parts_a, parts_b

def verbalize(self, label) -> List[str]:
if self.pattern_id == 4:
return [' true'] if label == 'entailment' else [' false']
return RtePVP.VERBALIZER[label]
```
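
Here `get_parts` assembles the raw `InputExample` into a prompt, with `[self.mask]` marking the position the model fills in and `None` marking the slots where continuous prompt tokens are inserted, while `verbalize` maps each label to the token(s) expected at the mask.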

## 2. Pass the user-defined class to the collate function
```python
collate_fn = ConstructSuperglueStrategy(cl_args,
tokenizer,
task_name=task_name,
custom_pvp=RtePVP)
```
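
From there the strategy is used exactly like the built-in ones; the full script `train_custom_pvp.py` (next file in this PR) passes it straight to the trainer:

```python
trainer.train(model,
              collate_fn=collate_fn,
              train_dataset=train_dataset,
              valid_dataset=valid_dataset,
              metric_methods=metric_methods)
```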
119 changes: 119 additions & 0 deletions examples/glm_custom_pvp/train_custom_pvp.py
@@ -0,0 +1,119 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.trainer import Trainer
from flagai.model.glm_model import GLMForSequenceClassification
from flagai.model.glm_model import GLMForSingleTokenCloze
from flagai.data.tokenizer import Tokenizer

from flagai.data.dataset import SuperGlueDataset
from flagai.test_utils import CollateArguments
from flagai.data.dataset.superglue.control import DEFAULT_METRICS, MULTI_TOKEN_TASKS, CH_TASKS
from flagai.data.dataset import ConstructSuperglueStrategy
from flagai.data.dataset.superglue.pvp import PVP
from flagai.data.dataset.data_utils import build_input_from_ids, build_sample, InputExample
from flagai.data.dataset.data_utils import build_decoder_input, build_decoder_sample, num_special_tokens_to_add
from typing import Tuple, List, Union, Dict
import string

class RtePVP(PVP):
VERBALIZER = {"not_entailment": [" No"], "entailment": [" Yes"]}

@staticmethod
def available_patterns():
return [0, 1, 2, 3, 4]

@property
def spell_length(self):
return self.num_prompt_tokens + self.prefix_prompt

def get_parts(self, example: InputExample):
# switch text_a and text_b to get the correct order
text_a = example.text_a
text_b = example.text_b.rstrip(string.punctuation)
if self.pattern_id == 0:
parts_a, parts_b = [None, '"',
self.shortenable(text_b), '" ?'], [
None, [self.mask], ',', None, ' "',
self.shortenable(text_a), '"'
]
elif self.pattern_id == 1:
parts_a, parts_b = [None, self.shortenable(text_b), '?'], [
None, [self.mask], ',', None,
self.shortenable(" " + text_a)
]
elif self.pattern_id == 2:
parts_a, parts_b = [
None,
self.shortenable(text_a), None, ' question:',
self.shortenable(" " + text_b), ' True or False?', None,
' answer:', [self.mask]
], []
else:
raise NotImplementedError(self.pattern_id)
parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b)
return parts_a, parts_b

def verbalize(self, label) -> List[str]:
if self.pattern_id == 4:
return [' true'] if label == 'entailment' else [' false']
return RtePVP.VERBALIZER[label]


# task_name options: ['boolq', 'cb', 'copa', 'multirc', 'rte', 'wic', 'wsc', 'afqmc', 'tnews']
task_name = "rte"

trainer = Trainer(env_type='pytorch',
epochs=10,
batch_size=4,
eval_interval=100,
log_interval=50,
experiment_name='glm_large',
pytorch_device='cuda',
load_dir=None,
lr=1e-4)
print("downloading...")

cl_args = CollateArguments()
cl_args.cloze_eval = True
cl_args.multi_token = task_name in MULTI_TOKEN_TASKS

cl_args.continuous_prompt = True
cl_args.prefix_prompt = 2
cl_args.num_prompt_tokens = 5

if task_name in CH_TASKS:
model_name = 'GLM-large-ch'
    add_block_symbols = True  # note: defined here but not used elsewhere in this example
else:
model_name = 'GLM-large-en'
tokenizer = Tokenizer.from_pretrained(model_name)

# model = GLMForSequenceClassification.from_pretrain(model_name=model_name, spell_length=2,
# class_num=3, tune_prefix_layers=1)

model = GLMForSingleTokenCloze.from_pretrain(download_path="./checkpoints",
model_name=model_name, spell_length=2,
class_num=3, tune_prefix_layers=1)
train_dataset = SuperGlueDataset(task_name=task_name,
data_dir='./datasets/',
dataset_type='train',
tokenizer=tokenizer)

collate_fn = ConstructSuperglueStrategy(cl_args,
tokenizer,
task_name=task_name,
custom_pvp=RtePVP)

valid_dataset = SuperGlueDataset(task_name=task_name,
data_dir='./datasets/',
dataset_type='dev',
tokenizer=tokenizer)

metric_methods = DEFAULT_METRICS[task_name]
trainer.train(model,
collate_fn=collate_fn,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
metric_methods=metric_methods)

23 changes: 23 additions & 0 deletions examples/glm_ptuning/README.md
@@ -0,0 +1,23 @@
# P-tuning

Here is an example of training a model with a continuous prompt (P-tuning).

## 1. Change the parameters in config
```python
cl_args.continuous_prompt = True  # enable continuous prompt
cl_args.prefix_prompt = 2         # number of continuous prompt tokens at the beginning
cl_args.num_prompt_tokens = 5     # number of continuous prompt tokens in the content
```


## 2. Change model parameters

```python
# spell_length is the final number of continuous prompt tokens in an instance; it is usually determined by the PVP structure
# tune_prefix_layers is the number of transformer layers to tune; the remaining layers are frozen
model = GLMForSingleTokenCloze.from_pretrain(download_path="./checkpoints",
model_name=model_name, spell_length=8,
tune_prefix_layers=1)
```

In this way, P-tuning is enabled during training.
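
The complete wiring of these settings into a dataset, collate strategy, and trainer is shown in `train_ptuning.py` (the training script in this PR); the core of it is:

```python
collate_fn = ConstructSuperglueStrategy(cl_args,
                                        tokenizer,
                                        task_name=task_name)

trainer.train(model,
              collate_fn=collate_fn,
              train_dataset=train_dataset,
              valid_dataset=valid_dataset,
              metric_methods=metric_methods)
```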
48 changes: 48 additions & 0 deletions examples/glm_ptuning/deepspeed.json
@@ -0,0 +1,48 @@
{
"train_micro_batch_size_per_gpu": 456,
"gradient_accumulation_steps": 100,
"steps_per_print": 100,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"contiguous_gradients": false,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e7,
"allgather_bucket_size": 5e7,
"cpu_offload": true
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 1e-5,
"warmup_num_steps": 2000
}
},
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-5,
"weight_decay": 0.1,
"betas": [
0.9,
0.98
],
"eps": 1e-6
}
},
"activation_checkpointing": {
"partition_activations": true,
"contiguous_memory_optimization": false
},
"wall_clock_breakdown": false
}
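
This config is consumed by the DeepSpeed-backed `Trainer` in `train_ptuning.py` below; abridged, the relevant arguments are:

```python
# Abridged from train_ptuning.py below; other Trainer arguments are omitted.
trainer = Trainer(env_type='deepspeed',
                  fp16=True,                           # matches "fp16": {"enabled": true} above
                  deepspeed_config='./deepspeed.json',
                  training_script=__file__)
```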
85 changes: 85 additions & 0 deletions examples/glm_ptuning/train_ptuning.py
@@ -0,0 +1,85 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.trainer import Trainer
from flagai.model.glm_model import GLMForSequenceClassification
from flagai.model.glm_model import GLMForSingleTokenCloze
from flagai.data.tokenizer import Tokenizer
from flagai.data.dataset import SuperGlueDataset
from flagai.test_utils import CollateArguments
from flagai.data.dataset.superglue.control import DEFAULT_METRICS, MULTI_TOKEN_TASKS, CH_TASKS
from flagai.data.dataset import ConstructSuperglueStrategy


# task_name options: ['boolq', 'cb', 'copa', 'multirc', 'rte', 'wic', 'wsc', 'afqmc', 'tnews']
task_name = "cb"

cl_args = CollateArguments()
cl_args.multi_token = task_name in MULTI_TOKEN_TASKS
cl_args.continuous_prompt = True
cl_args.prefix_prompt = 2
cl_args.num_prompt_tokens = 5
if task_name in CH_TASKS:
model_name = 'GLM-large-ch'
    add_block_symbols = True  # note: defined here but not used elsewhere in this example
else:
model_name = 'GLM-large-en'
tokenizer = Tokenizer.from_pretrained(model_name)

model = GLMForSingleTokenCloze.from_pretrain(download_path="./checkpoints",
model_name=model_name, spell_length=8,
tune_prefix_layers=1)
# model_save_path = "/home/yanzhaodong/anhforth/test/FlagAI/examples/glm_superglue/checkpoints/20000_save/pytorch_model.bin"
# model.load_state_dict(torch.load(model_save_path, map_location="cuda")["module"])
train_dataset = SuperGlueDataset(task_name=task_name,
data_dir='./datasets/',
dataset_type='train',
tokenizer=tokenizer)

collate_fn = ConstructSuperglueStrategy(cl_args,
tokenizer,
task_name=task_name)

valid_dataset = SuperGlueDataset(task_name=task_name,
data_dir='./datasets/',
dataset_type='dev',
tokenizer=tokenizer)

metric_methods = DEFAULT_METRICS[task_name]

# Deepspeed parallel trainer
trainer = Trainer(env_type='deepspeed',
epochs=10000000,
batch_size=16,
gradient_accumulation_steps=5,
checkpoint_activations=True,
eval_interval=False,
log_interval=100,
fp16=True,
save_interval=10000,
experiment_name='glm_large',
load_dir=None,
num_nodes=1,
num_gpus=2,
hostfile='./hostfile',
deepspeed_config='./deepspeed.json',
lr=1e-4,
training_script=__file__)
# Single-GPU trainer
# trainer = Trainer(env_type='pytorch',
# epochs=100,
# batch_size=1,
# eval_interval=100,
# log_interval=50,
# experiment_name='glm_large',
# pytorch_device='cuda',
# load_dir=None,
# lr=1e-4)

trainer.train(model,
collate_fn=collate_fn,
train_dataset=train_dataset,
valid_dataset=valid_dataset,
metric_methods=metric_methods)
