From cb48de103ab569ecb12909de08dcef5e59e8a696 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B6=A6=E5=AE=81?=
Date: Mon, 29 Jul 2024 14:43:26 +0800
Subject: [PATCH] support Gemma2

---
 ovis/model/modeling_ovis.py | 36 ++++++++++++++++++++++++++++++++++--
 ovis/train/arguments.py     |  3 ++-
 ovis/train/train.py         |  6 +++++-
 3 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/ovis/model/modeling_ovis.py b/ovis/model/modeling_ovis.py
index a3b0aec..1e37ee3 100644
--- a/ovis/model/modeling_ovis.py
+++ b/ovis/model/modeling_ovis.py
@@ -8,6 +8,7 @@
 from torch import Tensor, LongTensor, IntTensor
 from torch.nn import init
 from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
+from transformers.cache_utils import HybridCache
 from transformers.generation.utils import GenerateOutput
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config
 
@@ -46,7 +47,10 @@ def __init__(self, config: OvisConfig, *inputs, **kwargs):
             self.visual_tokenizer = kwargs['visual_tokenizer']
             self.config.visual_tokenizer_config = self.visual_tokenizer.config
         else:
-            self.llm = AutoModelForCausalLM.from_config(self.config.llm_config)
+            attn_kwargs = dict()
+            if kwargs.get('train_attn_implementation', None) is not None:
+                attn_kwargs['attn_implementation'] = kwargs.pop('train_attn_implementation')
+            self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
             assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
             self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
             self.visual_tokenizer = AutoModel.from_config(self.config.visual_tokenizer_config,
@@ -269,6 +273,33 @@ def save_pretrained(
         #                                       safe_serialization=safe_serialization)
         # self.get_visual_tokenizer().get_image_processor().save_pretrained(visual_tokenizer_directory)
 
+    def _get_hybrid_cache_for_llm(self, max_batch_size: int, max_cache_len: int):
+        cache_cls = HybridCache
+        llm = self.get_llm()
+
+        need_new_cache = (
+            not hasattr(llm, "_cache")
+            or (not isinstance(llm._cache, cache_cls))
+            or llm._cache.max_batch_size != max_batch_size
+            or llm._cache.max_cache_len < max_cache_len
+        )
+
+        if need_new_cache:
+            if hasattr(llm.config, "_pre_quantization_dtype"):
+                cache_dtype = llm.config._pre_quantization_dtype
+            else:
+                cache_dtype = llm.dtype
+            llm._cache = cache_cls(
+                config=llm.config,
+                max_batch_size=max_batch_size,
+                max_cache_len=max_cache_len,
+                device=llm.device,
+                dtype=cache_dtype,
+            )
+        else:
+            llm._cache.reset()
+        return llm._cache
+
     # TODO: support batch generation
     def generate(
         self,
@@ -283,7 +314,8 @@ def generate(
             pixel_values=kwargs.pop('pixel_values')
         )
         if getattr(self.generation_config, 'cache_implementation') == 'hybrid':  # mainly for Gemma2
-            kwargs['past_key_values'] = self.get_llm()._get_cache('hybrid', getattr(kwargs, "num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
+            kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
+                kwargs.get("num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
             self.get_llm()._supports_cache_class = True
             kwargs['cache_implementation'] = None
 
diff --git a/ovis/train/arguments.py b/ovis/train/arguments.py
index 6d02a17..598b178 100644
--- a/ovis/train/arguments.py
+++ b/ovis/train/arguments.py
@@ -34,11 +34,12 @@ class TrainingArguments(transformers.TrainingArguments):
     optim: str = field(default="adamw_torch")
     visual_max_tau: float = field(default=5.0)
     visual_min_tau: float = field(default=0.05)
-    save_safetensors: bool = field(default=False)
+    save_safetensors: bool = field(default=True)
     monitor_step: int = field(default=100)
     visual_re_init_layer_begin: Optional[int] = field(default=None)
     vte_re_init: bool = field(default=False)
     text_max_length: int = field(default=1024)
+    train_attn_implementation: Optional[str] = field(default=None)
 
     def __post_init__(self):
         if self.gradient_checkpointing:
diff --git a/ovis/train/train.py b/ovis/train/train.py
index c8e926e..c1cedcb 100644
--- a/ovis/train/train.py
+++ b/ovis/train/train.py
@@ -52,7 +52,10 @@ def args2dict(args):
         conversation_formatter_class=model_args.conversation_formatter_class
     )
     # 2. load pretrained llm and text tokenizer
-    llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path)
+    attn_kwargs = dict()
+    if training_args.train_attn_implementation is not None:
+        attn_kwargs['attn_implementation'] = training_args.train_attn_implementation
+    llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
     text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
     if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
         text_tokenizer.pad_token_id = model_args.pad_token_id
@@ -87,6 +90,7 @@ def args2dict(args):
     else: # load pretrained ovis model (S2, S3)
         model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
                                                    multimodal_max_length=model_args.multimodal_max_length,
+                                                   train_attn_implementation=training_args.train_attn_implementation,
                                                    output_loading_info=True)
         rank0_print(BEGIN_LINE)
         rank0_print(f'Loading info of Ovis:\n{loading_info}')