Fix qwen2 vl batch size (modelscope#2239)
Jintao-Huang authored Oct 14, 2024
1 parent 08edeb8 commit 5d64d39
Showing 4 changed files with 47 additions and 25 deletions.
docs/source/Instruction/命令行参数.md (2 changes: 1 addition & 1 deletion)
@@ -347,7 +347,7 @@ RLHF parameters inherit the sft parameters, with the following parameters added on top:
 - `--bnb_4bit_quant_type`: Default is `'nf4'`. See the `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--bnb_4bit_use_double_quant`: Default is `True`. See the `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--bnb_4bit_quant_storage`: Default is `True`. See the `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
-- `--🔥max_new_tokens`: Maximum number of new tokens to generate. Default is `2048`.
+- `--🔥max_new_tokens`: Maximum number of new tokens to generate. Default is `2048`. When deploying, control the maximum number of generated tokens by passing `max_tokens` from the client instead.
 - `--🔥do_sample`: Reference: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). Default is `None`, inheriting the model's generation_config.
 - `--temperature`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True, and is used as the default value in the deployment parameters.
 - `--top_k`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True, and is used as the default value in the deployment parameters.
docs/source_en/Instruction/Command-line-parameters.md (2 changes: 1 addition & 1 deletion)
@@ -348,7 +348,7 @@ RLHF parameters are an extension of the sft parameters, with the addition of the following parameters:
 - `--bnb_4bit_quant_type`: Default is `'nf4'`. See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--bnb_4bit_use_double_quant`: Default is `True`. See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
 - `--bnb_4bit_quant_storage`: Default is `None`. See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect.
-- `--🔥max_new_tokens`: Maximum number of new tokens to generate, default is `2048`.
+- `--🔥max_new_tokens`: Maximum number of new tokens to generate, default is `2048`. When deploying, control the maximum number of generated tokens by passing `max_tokens` from the client instead.
 - `--🔥do_sample`: Reference document: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). Default is `None`, inheriting the model's generation_config.
 - `--temperature`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True, and is used as the default value in the deployment parameters.
 - `--top_k`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True, and is used as the default value in the deployment parameters.
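The added note defers output-length control to the client at deployment time. A minimal sketch of such a call, assuming a `swift deploy` server exposing an OpenAI-compatible endpoint on localhost:8000 (the model name, port, and prompt are placeholders):

```python
# Hypothetical client-side cap on generated tokens; the server-side
# max_new_tokens default is not what governs deployment in this setup.
from openai import OpenAI

client = OpenAI(api_key='EMPTY', base_url='http://localhost:8000/v1')
resp = client.chat.completions.create(
    model='qwen2-vl-7b-instruct',
    messages=[{'role': 'user', 'content': 'Describe this picture.'}],
    max_tokens=512,  # per-request limit on generated tokens
)
print(resp.choices[0].message.content)
```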
swift/llm/export.py (2 changes: 1 addition & 1 deletion)
@@ -141,7 +141,7 @@ def get_block_name_to_quantize(model: nn.Module, model_type: str) -> Optional[str]:
     if module_lists:
         module_list = max(module_lists, key=lambda x: len(x[1]))
         _patch_model_forward(module_list[1])
-        return f'{prefix}.{module_list[0]}'
+        return f'{prefix}.{module_list[0]}' if prefix else module_list[0]


 def gptq_model_quantize(model, tokenizer, batch_size):
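The guard presumably handles the case where `prefix` is empty: the old f-string then produced a module path with a leading dot. A minimal sketch with hypothetical values:

```python
# Hypothetical values illustrating the fix:
prefix, block = '', 'model.layers'
old = f'{prefix}.{block}'                       # '.model.layers' (leading dot, not a valid module name)
new = f'{prefix}.{block}' if prefix else block  # 'model.layers'
print(old, new)
```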
swift/llm/utils/template.py (66 changes: 44 additions & 22 deletions)
@@ -1549,6 +1549,8 @@ def _process_image_qwen(image):


 class _Qwen2VLTemplateMixin:
+    image_token_id = 151655
+    video_token_id = 151656

     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                     example: Dict[str, Any]) -> List[Context]:
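The two new class attributes name the magic numbers that are replaced further down in this diff. Presumably they are the ids of Qwen2-VL's `<|image_pad|>` and `<|video_pad|>` placeholder tokens; a quick sanity check, assuming a local copy of the tokenizer is available:

```python
from transformers import AutoTokenizer

# Assumes network or cache access to the Qwen2-VL tokenizer.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-VL-7B-Instruct')
assert tokenizer.convert_tokens_to_ids('<|image_pad|>') == 151655
assert tokenizer.convert_tokens_to_ids('<|video_pad|>') == 151656
```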
@@ -1595,16 +1597,17 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         labels = inputs['labels']
         images = example.get('images') or []
         videos = example.get('videos') or []
+        data = {}
         for media_type in ['images', 'videos']:
             if locals()[media_type]:
                 if media_type == 'images':
-                    media_token = 151655
+                    media_token = self.image_token_id
                     media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
                     media_grid_thw = media_inputs['image_grid_thw']
                 else:
                     media_inputs = processor.image_processor(images=None, videos=videos, return_tensors='pt')
                     media_grid_thw = media_inputs['video_grid_thw']
-                    media_token = 151656
+                    media_token = self.video_token_id
                 idx_list = _findall(input_ids, media_token)
                 added_tokens_len = 0
                 for i, idx in enumerate(idx_list):
@@ -1617,32 +1620,51 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
                     labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx
                                                                                            + 1:]
                     added_tokens_len += token_len - 1
-                inputs.update(media_inputs)
+                data.update(media_inputs)

         inputs['input_ids'] = input_ids
         inputs['labels'] = labels
-        inputs['_data'] = {'plain_text': not images and not videos, 'input_ids': torch.tensor(input_ids)[None]}
+        data['input_ids'] = torch.tensor(input_ids)[None]
+        inputs['_data'] = data
         return inputs, {}
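For reference, `_encode` expands each single placeholder token into one token per merged visual patch; the `token_len` computation itself sits in the collapsed part of the hunk above. A toy sketch of the expansion, assuming the usual Qwen2-VL rule of `grid_thw.prod() // merge_size**2` tokens per image:

```python
import torch

# Toy values: one <|image_pad|> placeholder (id 151655) in the prompt.
input_ids = [9707, 151655, 11]
labels = [-100, -100, 11]
grid_thw = torch.tensor([1, 16, 16])  # (t, h, w) patch grid for the image
merge_size = 2                        # Qwen2-VL merges 2x2 patches into one token
token_len = int(grid_thw.prod()) // merge_size**2  # 64 visual tokens

idx = input_ids.index(151655)
input_ids = input_ids[:idx] + [151655] * token_len + input_ids[idx + 1:]
labels = labels[:idx] + [-100] * token_len + labels[idx + 1:]  # vision positions never contribute to the loss
```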

     def _post_encode(self, model, data: Any) -> Dict[str, Any]:
-        plain_text = data.pop('plain_text', False)
-        if is_deepspeed_enabled() and plain_text:
-            from PIL import Image
-            images = [Image.new('RGB', (32, 32), (0, 0, 0))]
-            processor = self.tokenizer.processor
-            media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
-            input_ids = data['input_ids']
-            device = input_ids.device
-            pixel_values = media_inputs['pixel_values'].to(device)
-            _model = model.model
-            if not hasattr(_model, 'embed_tokens'):
-                _model = _model.model  # LoRA
-            inputs_embeds = _model.embed_tokens(input_ids)
-            pixel_values = pixel_values.type(model.visual.get_dtype())
-            image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
-            inputs_embeds += image_embeds.mean() * 0.
-            return {'inputs_embeds': inputs_embeds[0]}
-        return {}
+        _model = model.model
+        if not hasattr(_model, 'embed_tokens'):
+            _model = _model.model  # LoRA
+        input_ids = data['input_ids']
+        pixel_values = data.get('pixel_values')
+        pixel_values_videos = data.get('pixel_values_videos')
+        inputs_embeds = _model.embed_tokens(input_ids)
+        if pixel_values is None and pixel_values_videos is None:  # plain-text
+            if is_deepspeed_enabled():
+                from PIL import Image
+                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+                processor = self.tokenizer.processor
+                media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
+                device = input_ids.device
+                pixel_values = media_inputs['pixel_values'].to(device)
+
+                pixel_values = pixel_values.type(model.visual.get_dtype())
+                image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+                inputs_embeds += image_embeds.mean() * 0.
+        else:
+            if pixel_values is not None:
+                image_grid_thw = data['image_grid_thw']
+                pixel_values = pixel_values.type(model.visual.get_dtype())
+                image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
+                image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                video_grid_thw = data['video_grid_thw']
+                pixel_values_videos = pixel_values_videos.type(model.visual.get_dtype())
+                video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+        return {'inputs_embeds': inputs_embeds[0]}

     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
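Two details of the new `_post_encode` are worth spelling out. For multimodal batches, visual features are written into the text embedding sequence via `masked_scatter`, which fills the masked positions in order; for plain-text batches under DeepSpeed, a dummy 32x32 black image is run through the visual tower and added as `image_embeds.mean() * 0.`, which keeps the tower's parameters in the computation graph without changing the embeddings. A toy illustration of the scatter mechanics (shapes and ids are made up):

```python
import torch

IMAGE_TOKEN_ID = 151655
input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
inputs_embeds = torch.zeros(1, 4, 8)  # [batch, seq_len, hidden]
image_embeds = torch.ones(2, 8)       # one row per placeholder position

# Broadcast the token mask over the hidden dimension, then scatter the
# visual rows into the placeholder slots, in sequence order.
mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds)
assert inputs_embeds[0, 1].eq(1).all() and inputs_embeds[0, 3].eq(0).all()
```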
