Support quant bert reward (#2859)
Jintao-Huang authored Jan 4, 2025
1 parent 967d8f0 commit 581a404
Showing 12 changed files with 94 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/source/Instruction/命令行参数.md
@@ -21,7 +21,7 @@
 - model_revision: Model version
 - task_type: Defaults to 'causal_lm'. Options are 'causal_lm', 'seq_cls'. Examples for seq_cls can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls).
 - 🔥torch_dtype: Data type of the model weights, supports `float16`, `bfloat16`, `float32`; read from the config file by default
-- attn_impl: Attention type, supports `flash_attn`, `sdpa`, `eager`; defaults to sdpa
+- attn_impl: Attention type, options are `flash_attn`, `sdpa`, `eager`; defaults to sdpa. Note: not all three implementations are necessarily supported; this depends on what the corresponding model supports.
 - num_labels: Must be specified for classification models. Represents the number of labels; defaults to None
 - rope_scaling: RoPE type, supports `linear`, `dynamic`; use together with `max_length`
 - device_map: The device_map configuration used by the model, e.g. 'auto', 'cpu', a JSON string, or a JSON file path
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -21,7 +21,7 @@ The introduction to command line parameters will cover base arguments, atomic ar
 - model_revision: Model version.
 - 🔥torch_dtype: Data type for model weights, supports `float16`, `bfloat16`, `float32`, default is read from the config file.
 - task_type: Defaults to 'causal_lm'. Options include 'causal_lm' and 'seq_cls'. You can view examples of seq_cls [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls).
-- attn_impl: Attention type, supports `flash_attn`, `sdpa`, `eager`, default is sdpa.
+- attn_impl: Attention type, options are `flash_attn`, `sdpa`, `eager`; default is `sdpa`. Note: not all three implementations are guaranteed to be supported; this depends on what the corresponding model supports.
 - num_labels: To be specified for classification models, representing the number of labels; default is None.
 - rope_scaling: RoPE type, supports `linear` and `dynamic`, to be used together with `max_length`.
 - device_map: Configuration of the device map used by the model, e.g. 'auto', 'cpu', a JSON string, or a JSON file path.
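For reference, the same options can be set through the Python API used in this commit's tests (a minimal sketch, not part of the commit: the `InferArguments` field names are assumed to mirror the CLI flags above, and the model path is a placeholder):

from swift.llm import infer_main, InferArguments

# Hypothetical example: run inference with an explicit weight dtype and
# attention implementation; whether attn_impl takes effect depends on the
# model, as noted above.
infer_main(
    InferArguments(
        model='output/swift_test_bert_bnb_int4',  # placeholder path
        torch_dtype='bfloat16',
        attn_impl='sdpa'))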
16 changes: 16 additions & 0 deletions examples/export/quantize/bert/bnb.sh
@@ -0,0 +1,16 @@
# merge-lora
CUDA_VISIBLE_DEVICES=0 swift export \
--adapters swift/test_bert \
--output_dir output/swift_test_bert_merged \
--merge_lora true

# bnb quantize
CUDA_VISIBLE_DEVICES=0 swift export \
--model output/swift_test_bert_merged \
--output_dir output/swift_test_bert_bnb_int4 \
--quant_bits 4 \
--quant_method bnb

# infer
CUDA_VISIBLE_DEVICES=0 swift infer \
--model output/swift_test_bert_bnb_int4
18 changes: 18 additions & 0 deletions examples/export/quantize/bert/gptq.sh
@@ -0,0 +1,18 @@
# merge-lora
CUDA_VISIBLE_DEVICES=0 swift export \
--adapters swift/test_bert \
--output_dir output/swift_test_bert_merged \
--merge_lora true

# gptq quantize
CUDA_VISIBLE_DEVICES=0 swift export \
--model output/swift_test_bert_merged \
--load_data_args true \
--output_dir output/swift_test_bert_gptq_int4 \
--quant_bits 4 \
--quant_method gptq \
--max_length 512

# infer
CUDA_VISIBLE_DEVICES=0 swift infer \
--model output/swift_test_bert_gptq_int4
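Both BERT scripts merge the LoRA adapter into the base model before quantizing: quantization rewrites the full weight matrices, so it must run on the merged checkpoint rather than on the adapter alone.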
12 changes: 12 additions & 0 deletions examples/export/quantize/reward_model/bnb.sh
@@ -0,0 +1,12 @@
# bnb quantize
CUDA_VISIBLE_DEVICES=0 swift export \
--model Shanghai_AI_Laboratory/internlm2-1_8b-reward \
--output_dir output/internlm2-1_8b-reward-bnb-int4 \
--quant_bits 4 \
--quant_method bnb

# infer
CUDA_VISIBLE_DEVICES=0 swift infer \
--model output/internlm2-1_8b-reward-bnb-int4 \
--val_dataset 'AI-ModelScope/alpaca-gpt4-data-zh#1000' \
--max_batch_size 16
13 changes: 13 additions & 0 deletions examples/export/quantize/reward_model/gptq.sh
@@ -0,0 +1,13 @@
# gptq quantize
CUDA_VISIBLE_DEVICES=0 swift export \
--model Shanghai_AI_Laboratory/internlm2-1_8b-reward \
--output_dir output/internlm2-1_8b-reward-gptq-int4 \
--quant_bits 4 \
--quant_method gptq \
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#1000' 'AI-ModelScope/alpaca-gpt4-data-en#1000'

# infer
CUDA_VISIBLE_DEVICES=0 swift infer \
--model output/internlm2-1_8b-reward-gptq-int4 \
--val_dataset 'AI-ModelScope/alpaca-gpt4-data-zh#1000' \
--max_batch_size 16
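Note the asymmetry between the two quantization methods in these scripts: GPTQ is calibration-based, so its scripts supply quantization data (via `--dataset` here, or via `--load_data_args` in the BERT script, which presumably reuses the dataset arguments saved during training), whereas bitsandbytes quantizes the weights directly and needs no calibration set.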
1 change: 0 additions & 1 deletion swift/llm/argument/base_args/base_args.py
@@ -240,7 +240,6 @@ def get_template(self, processor: 'Processor') -> 'Template':
         template_kwargs = self.get_template_kwargs()
         template = get_template(self.template, processor, **template_kwargs)
         logger.info(f'default_system: {template.template_meta.default_system}')
-        template.set_mode(self.task_type)  # default mode
         return template

     def get_model_processor(self, *, model=None, model_type=None, model_revision=None, **kwargs):
4 changes: 1 addition & 3 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -394,9 +394,7 @@ def _infer(
             template.model = self.model

         generation_config = None
-        if self.model_info.task_type == 'seq_cls':
-            template.set_mode('seq_cls')
-        else:
+        if self.model_info.task_type == 'causal_lm':
             template.set_mode('pt')

         max_workers = min(32, os.cpu_count(), len(infer_requests))
11 changes: 6 additions & 5 deletions swift/llm/template/base.py
@@ -108,7 +108,7 @@ def __init__(
         self.skip_prompt = False

         self.mode: Literal['pt', 'vllm', 'lmdeploy',  # infer
-                           'train', 'rlhf', 'kto'  # train
+                           'train', 'rlhf', 'kto',  # train
                            'seq_cls'] = 'pt'
         if self.model_info.task_type != 'causal_lm':
             self.mode = self.model_info.task_type
@@ -583,7 +583,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
             context_list = prompt.copy()
             extra_context_list = []
             extra_context_type = None
-            if i < n_round - 1 or self.mode == 'seq_cls':
+            if i < n_round - 1 or self.mode == 'seq_cls' and response is not None:
                 # Not the last round.
                 context_list.append('{{RESPONSE}}')
                 extra_context_list = template_meta.chat_sep
@@ -721,9 +721,10 @@ def is_training(self):
         return self.mode not in {'vllm', 'lmdeploy', 'pt'}

     def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
-        if mode == 'causal_lm':
-            mode = 'train'
-        self.mode = mode
+        if self.model_info.task_type == 'causal_lm':
+            self.mode = mode
+        else:
+            logger.warning(f'task_type: `{self.model_info.task_type}` does not support modifying template.mode.')

     def register_post_encode_hook(self, models: List[nn.Module]) -> None:
         """This function is important for multi-modal training, as it registers the post_encode method
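Taken together with the pt_engine.py change above and the sft.py change below, the new contract is that a template's mode is pinned to the task type at construction unless the model is a causal LM. A hypothetical, self-contained sketch of that behavior (stand-in classes, not the actual Template API):

import logging
from dataclasses import dataclass
from typing import Literal

logger = logging.getLogger(__name__)


@dataclass
class ModelInfoSketch:
    task_type: str  # 'causal_lm' or 'seq_cls'


class TemplateSketch:
    """Illustration of the set_mode logic in the diff above."""

    def __init__(self, model_info: ModelInfoSketch):
        self.model_info = model_info
        # Non-causal_lm templates are pinned to their task type at construction.
        self.mode = 'pt' if model_info.task_type == 'causal_lm' else model_info.task_type

    def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
        if self.model_info.task_type == 'causal_lm':
            self.mode = mode
        else:
            logger.warning(f'task_type: `{self.model_info.task_type}` does not support modifying template.mode.')


reward = TemplateSketch(ModelInfoSketch('seq_cls'))
reward.set_mode('pt')     # warning no-op: mode stays 'seq_cls'
assert reward.mode == 'seq_cls'

causal = TemplateSketch(ModelInfoSketch('causal_lm'))
causal.set_mode('train')  # allowed: causal-LM templates switch modes freely
assert causal.mode == 'train'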
4 changes: 2 additions & 2 deletions swift/llm/template/template/qwen.py
@@ -7,12 +7,12 @@

 from swift.utils import get_env_args, is_deepspeed_enabled
 from ..base import Template
-from ..constant import LLMTemplateType, MLLMTemplateType, RMTemplateType
+from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import register_template
 from ..template_inputs import StdTemplateInputs
 from ..template_meta import TemplateMeta
 from ..utils import Context, Word, findall
-from ..vision_utils import load_audio_qwen, load_batch, load_file
+from ..vision_utils import load_audio_qwen, load_batch
 from .llama import Llama3TemplateMeta
 from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta

2 changes: 2 additions & 0 deletions swift/llm/train/sft.py
@@ -72,6 +72,8 @@ def _prepare_model_tokenizer(self):

     def _prepare_template(self) -> None:
         template = self.args.get_template(self.processor)
+        if self.args.task_type == 'causal_lm':
+            template.set_mode('train')
         if template.use_model:
             template.model = self.model
         self.template = template
23 changes: 22 additions & 1 deletion tests/export/quant.py
@@ -41,8 +41,29 @@ def test_vlm_bnb_quant():
     # infer_main(InferArguments(ckpt_dir='Qwen/Qwen2-VL-7B-Instruct-bnb-int4'))


+def test_bert():
+    from swift.llm import export_main, ExportArguments
+    output_dir = 'output/swift_test_bert_merged'
+    export_main(ExportArguments(adapters='swift/test_bert', merge_lora=True, output_dir=output_dir))
+    export_main(
+        ExportArguments(model=output_dir, load_data_args=True, quant_bits=4, quant_method='gptq', max_length=512))
+
+
+def test_reward_model():
+    from swift.llm import export_main, ExportArguments
+
+    export_main(
+        ExportArguments(
+            model='Shanghai_AI_Laboratory/internlm2-1_8b-reward',
+            dataset=['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000'],
+            quant_bits=4,
+            quant_method='gptq'))
+
+
 if __name__ == '__main__':
     # test_llm_quant('gptq')
     # test_vlm_quant('gptq')
-    test_audio_quant('gptq')
+    # test_audio_quant('gptq')
     # test_vlm_bnb_quant()
+    # test_bert()
+    test_reward_model()
