
support Gemma2
runninglsy committed Jul 29, 2024
1 parent 7c8f0dc commit cb48de1
Showing 3 changed files with 41 additions and 4 deletions.
36 changes: 34 additions & 2 deletions ovis/model/modeling_ovis.py
@@ -8,6 +8,7 @@
 from torch import Tensor, LongTensor, IntTensor
 from torch.nn import init
 from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
+from transformers.cache_utils import HybridCache
 from transformers.generation.utils import GenerateOutput
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config

@@ -46,7 +47,10 @@ def __init__(self, config: OvisConfig, *inputs, **kwargs):
             self.visual_tokenizer = kwargs['visual_tokenizer']
             self.config.visual_tokenizer_config = self.visual_tokenizer.config
         else:
-            self.llm = AutoModelForCausalLM.from_config(self.config.llm_config)
+            attn_kwargs = dict()
+            if kwargs.get('train_attn_implementation', None) is not None:
+                attn_kwargs['attn_implementation'] = kwargs.pop('train_attn_implementation')
+            self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
             assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
             self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
             self.visual_tokenizer = AutoModel.from_config(self.config.visual_tokenizer_config,
@@ -269,6 +273,33 @@ def save_pretrained(
         # safe_serialization=safe_serialization)
         # self.get_visual_tokenizer().get_image_processor().save_pretrained(visual_tokenizer_directory)
 
+    def _get_hybrid_cache_for_llm(self, max_batch_size: int, max_cache_len: int):
+        cache_cls = HybridCache
+        llm = self.get_llm()
+
+        need_new_cache = (
+            not hasattr(llm, "_cache")
+            or (not isinstance(llm._cache, cache_cls))
+            or llm._cache.max_batch_size != max_batch_size
+            or llm._cache.max_cache_len < max_cache_len
+        )
+
+        if need_new_cache:
+            if hasattr(llm.config, "_pre_quantization_dtype"):
+                cache_dtype = llm.config._pre_quantization_dtype
+            else:
+                cache_dtype = llm.dtype
+            llm._cache = cache_cls(
+                config=llm.config,
+                max_batch_size=max_batch_size,
+                max_cache_len=max_cache_len,
+                device=llm.device,
+                dtype=cache_dtype,
+            )
+        else:
+            llm._cache.reset()
+        return llm._cache
+
     # TODO: support batch generation
     def generate(
         self,
@@ -283,7 +314,8 @@
             pixel_values=kwargs.pop('pixel_values')
         )
         if getattr(self.generation_config, 'cache_implementation') == 'hybrid': # mainly for Gemma2
-            kwargs['past_key_values'] = self.get_llm()._get_cache('hybrid', getattr(kwargs, "num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
+            kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
+                kwargs.get("num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
             self.get_llm()._supports_cache_class = True
             kwargs['cache_implementation'] = None

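For context, the two hunks above pre-build a transformers HybridCache (the static cache layout Gemma2 uses) sized for the multimodal prompt plus max_new_tokens, and hand it to the LLM through past_key_values while switching cache_implementation off. Below is a minimal standalone sketch of that pattern against a plain Gemma2 checkpoint; it is not part of the commit, the checkpoint name and prompt are placeholders, and the HybridCache keyword arguments simply mirror the call in _get_hybrid_cache_for_llm.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import HybridCache

model_id = "google/gemma-2-9b-it"  # placeholder Gemma2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Describe the image in one sentence.", return_tensors="pt").to(llm.device)
max_new_tokens = 64

# Pre-build a hybrid (static) cache sized for prompt length + generated tokens,
# the same sizing _get_hybrid_cache_for_llm derives from inputs_embeds and max_new_tokens.
cache = HybridCache(
    config=llm.config,
    max_batch_size=1,
    max_cache_len=inputs.input_ids.shape[-1] + max_new_tokens,
    device=llm.device,
    dtype=llm.dtype,
)

# Pass the cache explicitly and disable the default 'hybrid' cache_implementation,
# as Ovis.generate() now does for Gemma2-backed models.
output_ids = llm.generate(
    **inputs,
    past_key_values=cache,
    max_new_tokens=max_new_tokens,
    cache_implementation=None,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Because the helper stores the cache on the LLM and only rebuilds it when the batch size changes or the required length no longer fits, repeated generate() calls reuse one allocation, and reset() clears it in between.
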
3 changes: 2 additions & 1 deletion ovis/train/arguments.py
@@ -34,11 +34,12 @@ class TrainingArguments(transformers.TrainingArguments):
     optim: str = field(default="adamw_torch")
     visual_max_tau: float = field(default=5.0)
     visual_min_tau: float = field(default=0.05)
-    save_safetensors: bool = field(default=False)
+    save_safetensors: bool = field(default=True)
     monitor_step: int = field(default=100)
     visual_re_init_layer_begin: Optional[int] = field(default=None)
     vte_re_init: bool = field(default=False)
     text_max_length: int = field(default=1024)
+    train_attn_implementation: Optional[str] = field(default=None)
 
     def __post_init__(self):
         if self.gradient_checkpointing:
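The new train_attn_implementation field reaches the training script as a command-line flag, presumably via transformers' HfArgumentParser, as is standard for TrainingArguments subclasses. A small sketch of that parsing step follows; the dataclass is trimmed to the field added here and all argument values are invented for illustration ('eager' is a typical choice for Gemma2, whose attention logit soft-capping was not supported by flash-attention at release).

from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Trimmed stand-in for ovis/train/arguments.py, keeping only the field added in this commit.
    train_attn_implementation: Optional[str] = field(default=None)


parser = transformers.HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(args=[
    "--output_dir", "outputs/ovis-gemma2",  # placeholder path
    "--train_attn_implementation", "eager",
])
print(training_args.train_attn_implementation)  # -> eager
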
6 changes: 5 additions & 1 deletion ovis/train/train.py
@@ -52,7 +52,10 @@ def args2dict(args):
     conversation_formatter_class=model_args.conversation_formatter_class
 )
 # 2. load pretrained llm and text tokenizer
-llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path)
+attn_kwargs = dict()
+if training_args.train_attn_implementation is not None:
+    attn_kwargs['attn_implementation'] = training_args.train_attn_implementation
+llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
 text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
 if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
     text_tokenizer.pad_token_id = model_args.pad_token_id
@@ -87,6 +90,7 @@ def args2dict(args):
 else: # load pretrained ovis model (S2, S3)
     model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
                                                multimodal_max_length=model_args.multimodal_max_length,
+                                               train_attn_implementation=training_args.train_attn_implementation,
                                                output_loading_info=True)
     rank0_print(BEGIN_LINE)
     rank0_print(f'Loading info of Ovis:\n{loading_info}')
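train.py applies the same opt-in pattern when loading the base LLM: the attention backend is forwarded to from_pretrained only when the flag is set, so transformers' default backend selection is untouched otherwise. A self-contained sketch of that loading logic is below; the checkpoint name and dtype are placeholders chosen for the sketch, not values taken from the script.

from typing import Optional

import torch
from transformers import AutoModelForCausalLM


def load_llm(name_or_path: str, train_attn_implementation: Optional[str] = None):
    # Forward attn_implementation only when explicitly requested, as in the hunk above.
    attn_kwargs = {}
    if train_attn_implementation is not None:
        # e.g. "eager", "sdpa", or "flash_attention_2"
        attn_kwargs["attn_implementation"] = train_attn_implementation
    return AutoModelForCausalLM.from_pretrained(
        name_or_path,
        torch_dtype=torch.bfloat16,  # dtype chosen for the sketch only
        **attn_kwargs,
    )


llm = load_llm("google/gemma-2-9b", train_attn_implementation="eager")  # placeholder checkpoint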
