refactor internvl2 (modelscope#1625)
Jintao-Huang authored Aug 13, 2024
1 parent 68ead01 commit e27cdaf
Showing 2 changed files with 35 additions and 172 deletions.
151 changes: 10 additions & 141 deletions swift/llm/utils/model.py
@@ -1362,8 +1362,6 @@ def get_model_tokenizer_paligemma_vision(model_dir: str,
    model, tokenizer = get_model_tokenizer_from_repo(
        model_dir, torch_dtype, model_kwargs, load_model, automodel_class=PaliGemmaForConditionalGeneration, **kwargs)
    tokenizer.processor = processor
    if model is not None:
        model.max_position_embeddings = model.language_model.config.max_position_embeddings
    return model, tokenizer


@@ -4037,85 +4035,6 @@ def _new_forward(*args, **kwargs):
        embedding.forward = _new_forward


def _patch_internvl_forward(forward_func):
    from transformers.modeling_outputs import CausalLMOutputWithPast
    from torch.nn import CrossEntropyLoss

    def wrapper(
        self,
        pixel_values: torch.FloatTensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if pixel_values is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits
            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
            if not return_dict:
                output = (logits, ) + outputs[1:]
                return (loss, ) + output if loss is not None else output
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            return forward_func(
                pixel_values,
                input_ids,
                attention_mask,
                position_ids,
                image_flags,
                past_key_values,
                labels,
                use_cache,
                output_attentions,
                output_hidden_states,
                return_dict,
            )

    return wrapper


def patch_internvl_forward(model) -> None:
    if not hasattr(model, '__old_forward'):  # Avoid double patching
        forward = model.forward
        model.__old_forward = forward
        model.forward = MethodType(_patch_internvl_forward(model.forward), model)


@register_model(
    ModelType.internvl_chat_v1_5,
    'AI-ModelScope/InternVL-Chat-V1-5',
@@ -4157,8 +4076,8 @@ def patch_internvl_forward(model) -> None:
    TemplateType.internvl_phi3,
    requires=['transformers>=4.35,<4.42', 'timm'],
    support_flash_attn=True,
    support_lmdeploy=True,
    support_vllm=True,
    eos_token='<|end|>',
    placeholder_tokens=['<IMG_CONTEXT>'],
    tags=['multi-modal', 'vision'],
    hf_model_id='OpenGVLab/Mini-InternVL-Chat-4B-V1-5')
@@ -4198,6 +4117,7 @@ def patch_internvl_forward(model) -> None:
    support_flash_attn=True,
    support_lmdeploy=True,
    support_vllm=True,
    eos_token='<|end|>',
    placeholder_tokens=['<IMG_CONTEXT>'],
    tags=['multi-modal', 'vision'],
    hf_model_id='OpenGVLab/InternVL2-4B')
@@ -4258,6 +4178,11 @@ def get_model_tokenizer_internvl(model_dir: str,
                                 model_kwargs: Dict[str, Any],
                                 load_model: bool = True,
                                 **kwargs):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
    if kwargs.get('eos_token') is None:
        del tokenizer.__class__.eos_token_id
        tokenizer.eos_token = '<|im_end|>'

    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    use_flash_attn = kwargs.pop('use_flash_attn', False)
    model_config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
@@ -4270,7 +4195,6 @@ def get_model_tokenizer_internvl(model_dir: str,
    if isinstance(quantization_config, BitsAndBytesConfig):
        use_bnb = True

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
    model, tokenizer = get_model_tokenizer_from_repo(
        model_dir, torch_dtype, model_kwargs, load_model, tokenizer=tokenizer, model_config=model_config, **kwargs)

@@ -4280,64 +4204,9 @@ def get_model_tokenizer_internvl(model_dir: str,
        model.language_model.output.state.force_no_igemmlt = True

    if model is not None:
        model.config.max_position_embeddings = model.language_model.config.max_position_embeddings
        _use_submodel_func(model, 'language_model', ['get_input_embeddings', 'gradient_checkpointing_enable'])
        func_list = ['generate', 'get_input_embeddings', 'gradient_checkpointing_enable', 'forward']
        _use_submodel_func(model, 'language_model', func_list)
        fix_internvl_inplace_bug(model)
        patch_internvl_forward(model)

        if not hasattr(model, '__old_generate'):
            generate = model.generate
            model.__old_generate = generate

            @wraps(generate)
            def _new_generate(*args, **kwargs):
                kwargs.pop('image_flags', None)
                return generate(*args, **kwargs)

            model.generate = _new_generate

        if not hasattr(model, '_old_extract_feature'):
            extract_feature = model.extract_feature
            model._old_extract_feature = extract_feature

            @wraps(extract_feature)
            def _new_extract_feature(pixel_values):
                return extract_feature(pixel_values).to(pixel_values.device).to(pixel_values.dtype)

            model.extract_feature = _new_extract_feature

        if not hasattr(model.language_model, '__old_forward'):  # Avoid double patching
            old_forward = model.language_model.forward
            model.language_model.__old_forward = old_forward

            @wraps(old_forward)
            def _new_forward(*args, **kwargs):
                input_ids: Optional[Tensor] = kwargs.get('input_ids', None)
                input_embeds: Optional[Tensor] = kwargs.get('inputs_embeds', None)
                device = input_ids.device if input_ids is not None else input_embeds.device
                output = old_forward(*args, **kwargs)
                output['logits'] = output['logits'].to(device)
                return output

            model.language_model.forward = _new_forward

        IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        model.img_context_token_id = img_context_token_id
        if not hasattr(model.config, 'hidden_size'):
            model.config.hidden_size = model.config.llm_config.hidden_size
        # fix single GPU bug
        if not hasattr(dist, '_old_get_rank'):
            get_rank = dist.get_rank

            @wraps(get_rank)
            def new_get_rank(group=None):
                if not dist.is_initialized() or dist.get_world_size() == 1:
                    return -1
                return get_rank(group)

            dist.get_rank = new_get_rank
            dist._old_get_rank = get_rank
    return model, tokenizer


@@ -4370,7 +4239,7 @@ def new_get_rank(group=None):
    TemplateType.internlm_xcomposer2_4khd,
    support_flash_attn=True,
    support_lmdeploy=True,
    eos_token='[UNUSED_TOKEN_145]',
    eos_token='<|im_end|>',
    function_kwargs={'version': 'v2-4khd'},
    tags=['multi-modal', 'vision'],
    hf_model_id='internlm/internlm-xcomposer2-4khd-7b')
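The net effect of the model.py changes, restated outside the diff: the bespoke forward/generate/extract_feature patches are deleted, and the wrapper model instead delegates those methods straight to the inner language_model, while the EOS token is fixed up on the tokenizer before the model is built. Below is a minimal sketch of the two patterns, assuming a generic InternVL checkpoint; the names load_internvl_tokenizer and use_submodel_func are illustrative stand-ins (swift's real _use_submodel_func helper lives elsewhere in model.py and does extra bookkeeping):

from transformers import AutoTokenizer


def load_internvl_tokenizer(model_dir: str):
    # The remote-code tokenizer pins eos_token_id at class level, which would
    # shadow a per-instance override; deleting the class attribute first lets
    # the instance-level eos_token win, with the id re-derived from it.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
    del tokenizer.__class__.eos_token_id
    tokenizer.eos_token = '<|im_end|>'
    return tokenizer


def use_submodel_func(model, submodel_name: str, func_list) -> None:
    # Route the listed methods on the wrapper model to the submodule's
    # implementations, so no bespoke forward wrapper is needed.
    submodel = getattr(model, submodel_name)
    for name in func_list:
        setattr(model, name, getattr(submodel, name))

With forward delegated to the language model, the image_flags stripping and logits device fixes deleted above become unnecessary; merging image features into the input embeddings moves to the template (see the template.py diff below).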
56 changes: 25 additions & 31 deletions swift/llm/utils/template.py
@@ -1641,11 +1641,12 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            return inputs, {}
        input_ids = inputs['input_ids']
        idx_list = _findall(input_ids, -100)
        labels = inputs.get('labels')
        pixel_values = None
        images = example.get('images')
        if images:
            labels = inputs.get('labels')
            pixel_values_images = [transform_image(image) for image in images]
            pixel_values = torch.cat(pixel_values_images, dim=0)
            pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model.dtype)
            image_bs = pixel_values.shape[0]

            idx, idx2 = idx_list[0], idx_list[-1]  # remove [-100, -100]
@@ -1656,20 +1657,22 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
                labels = labels[:idx] + [-100] * len(img_tokens) + labels[idx2 + 1:]
            inputs['input_ids'] = input_ids
            inputs['labels'] = labels

            inputs['pixel_values'] = pixel_values.to(self.model.dtype)
            inputs['image_flags'] = torch.ones(image_bs)

        inputs['_data'] = {'input_ids': torch.tensor(input_ids), 'pixel_values': pixel_values}
        inputs.pop('loss_scale', None)
        return inputs, {}

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super().data_collator(batch, padding_to)
        if any('pixel_values' in b for b in batch):
            image_flags = [b['image_flags'] for b in batch if 'image_flags' in b]
            if image_flags:
                res['image_flags'] = torch.concat(image_flags)
        return res

    def _post_encode(self, data: Any) -> Dict[str, Any]:
        embedding = self.model.get_input_embeddings()
        device = embedding.weight.device
        input_ids = data['input_ids']
        inputs_embeds = embedding(input_ids)
        pixel_values = data['pixel_values']
        if pixel_values is not None:
            pixel_values = pixel_values.to(device=device)
            vit_embeds = self.model.extract_feature(pixel_values)
            selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
            inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
        return {'inputs_embeds': inputs_embeds}

    @staticmethod
    def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
@@ -1680,12 +1683,6 @@ class Internvl2Template(InternvlTemplate):
    video_segments = 8
    system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'

    def __init__(self):
        Template.__init__(
            self, [], ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'], ['<|im_end|>'], ['<|im_end|>'],
            self.system, ['<|im_start|>system\n{{SYSTEM}}<|im_end|>'],
            auto_add_bos=True)

    def replace_tag(self, media_type, index, example) -> List[Context]:
        if self._is_vllm:
            image_context = ['<img><image></img>\n']
@@ -1737,10 +1734,9 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            has_video = bool(example.get('videos'))
            pixel_values = [transform_image(image, max_num=1 if has_video else 12) for image in images]
            num_patches = [pv.shape[0] for pv in pixel_values]
            pixel_values = torch.cat(pixel_values)
            inputs['pixel_values'] = pixel_values.to(self.model.dtype)
            inputs['image_flags'] = torch.ones(sum(num_patches))
            pixel_values = torch.cat(pixel_values).to(self.model.dtype)
        else:
            pixel_values = None
            num_patches = []
        assert len(num_patches) == len(
            idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
@@ -1755,12 +1751,12 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
                added_tokens_len += len(img_tokens) - 1
        inputs['input_ids'] = input_ids
        inputs['labels'] = labels
        inputs['_data'] = {'input_ids': torch.tensor(input_ids), 'pixel_values': pixel_values}
        inputs.pop('loss_scale', None)
        return inputs, {}


class InternvlPhi3Template(InternvlTemplate):
    system = 'You are an AI assistant whose name is Phi-3.'
class InternvlPhi3TemplateMixin:

    def __init__(self):
        Template.__init__(
@@ -1770,14 +1766,12 @@ def __init__(self):
        self.padding_side = 'left'


class Internvl2Phi3Template(Internvl2Template):
class InternvlPhi3Template(InternvlPhi3TemplateMixin, InternvlTemplate):
    system = 'You are an AI assistant whose name is Phi-3.'

    def __init__(self):
        Template.__init__(
            self, [], ['<|user|>\n{{QUERY}}<|end|><|assistant|>\n'], ['<|end|>'], ['<|end|>'],
            self.system, ['<|system|>\n{{SYSTEM}}<|end|>'],
            auto_add_bos=True)
        self.padding_side = 'left'

class Internvl2Phi3Template(InternvlPhi3TemplateMixin, Internvl2Template):
    pass


register_template(
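With the collator-time image_flags plumbing removed, the template now builds inputs_embeds itself via the new _post_encode. The same merge, restated as a standalone sketch (names mirror the diff; extract_feature is the vision-encoder entry point InternVL's remote modeling code exposes, and merge_image_embeddings is a hypothetical name for illustration):

import torch


def merge_image_embeddings(model, tokenizer, input_ids, pixel_values):
    # Embed the text tokens, then overwrite every <IMG_CONTEXT> position with
    # one row of the flattened ViT output, as _post_encode does above.
    embedding = model.get_input_embeddings()
    inputs_embeds = embedding(input_ids)
    if pixel_values is not None:
        vit_embeds = model.extract_feature(pixel_values.to(embedding.weight.device))
        img_context_id = tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0]
        selected = (input_ids == img_context_id)
        inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
    return inputs_embeds

Since generation now flows through inputs_embeds, the deleted patch_internvl_forward wrapper on the model side is no longer needed.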
