support evaluation for qwenvl and internvl (#266)
chengtao-lv authored Dec 16, 2024
1 parent 83e37db commit 7c5dee4
Showing 4 changed files with 154 additions and 10 deletions.
2 changes: 1 addition & 1 deletion llmc/eval/eval_vqa.py
@@ -232,6 +232,6 @@ def _adjust_config(task_dict):
results['date'] = datetime_str
# add_env_info(results) # additional environment info to results
# add_tokenizer_info(results, lm) # additional info about tokenizer
return make_table(results)
return '\n' + make_table(results)
else:
return None
83 changes: 81 additions & 2 deletions llmc/models/internvl2.py
@@ -1,11 +1,18 @@
from datetime import timedelta
from typing import Optional

import torch
import torchvision.transforms as T
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from accelerate.utils import InitProcessGroupKwargs
from lmms_eval.api.model import lmms
from lmms_eval.models.internvl2 import InternVL2 as LMMS_InternVL2
from loguru import logger
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
GenerationConfig)

from llmc.utils.registry_factory import MODEL_REGISTRY

@@ -184,7 +191,7 @@ def __init__(self, config, device_map=None, use_cache=False):

class InternVL2SharedBehavior():
def build_model(self):
self.eval_name = 'InternVL2'
self.eval_name = 'InternVL2Eval'
self.vlm_model_config = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
@@ -364,3 +371,75 @@ def get_subsets_in_block(self, block):
]
else:
            raise Exception(f'InternVL2 does not support {self.get_modality()} modality.')


@MODEL_REGISTRY
class InternVL2Eval(LMMS_InternVL2):
def __init__(
self,
llmc_model,
pretrained: str = 'OpenGVLab/InternVL2-2B',
modality: str = 'image',
device: str = 'cuda:0',
device_map: str = 'cuda:0',
batch_size: str = '1',
**kwargs,
):
lmms.__init__(self)

self.path = pretrained
self._model = llmc_model.cuda()
self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)

batch_size = int(batch_size)
assert batch_size == 1, f'Batch size should be 1 for InternVL2, but got {batch_size}.'
self.batch_size_per_gpu = batch_size

accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'
elif accelerator.num_processes == 1 and device_map == 'auto':
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'

if accelerator.num_processes > 1:
assert accelerator.distributed_type in \
[DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], \
                'Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed are supported.'

if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(
must_match=True, **kwargs)
logger.info('Detected that you are using DistributedType.DEEPSPEED.')

if accelerator.distributed_type == DistributedType.FSDP or \
accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == 'auto':
logger.info(f'Using {accelerator.num_processes} devices with tensor parallelism')
self._rank = 0
            self._world_size = 1
else:
logger.info(f'Using single device: {self._device}')
self.model.to(self._device)
self._rank = 0
self._world_size = 1

self.modality = modality
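
For orientation, here is a minimal sketch of how the newly registered InternVL2Eval wrapper might be looked up and constructed around a model that llmc has already loaded (and possibly quantized). The helper name, the dictionary-style registry lookup, and the call site are illustrative assumptions; only MODEL_REGISTRY and the constructor signature come from the diff above.

# Hypothetical call site, not part of this commit: reuse the llmc-loaded model
# for lmms-eval instead of reloading the checkpoint from disk.
from llmc.utils.registry_factory import MODEL_REGISTRY


def build_internvl2_eval(llmc_model, model_path):
    # 'InternVL2Eval' matches the eval_name set in InternVL2SharedBehavior.build_model.
    eval_cls = MODEL_REGISTRY['InternVL2Eval']  # assumes dict-style lookup on the registry
    # Only the tokenizer is re-created from the original checkpoint path; the
    # (quantized) weights come from llmc_model itself.
    return eval_cls(llmc_model, pretrained=model_path, batch_size='1')
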
9 changes: 3 additions & 6 deletions llmc/models/llava.py
@@ -3,7 +3,7 @@
import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from lmms_eval.api.model import CacheHook
from lmms_eval.api.model import lmms
from lmms_eval.models.llava_hf import LlavaHf
from loguru import logger
from PIL import Image
@@ -19,7 +19,6 @@
class Llava(Llama):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)
self.eval_name = 'LlavaHfEval'

def build_model(self):
self.vlm_model_config = AutoConfig.from_pretrained(
@@ -34,6 +33,7 @@ def build_model(self):
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
)
self.eval_name = 'LlavaHfEval'
self.mm_model = self.vlm_model
logger.info(f'self.vlm_model : {self.vlm_model}')
self.vision_model = self.vlm_model.vision_tower
@@ -181,10 +181,7 @@ def __init__(
**kwargs,
) -> None:

self._rank = 0
self._world_size = 1
self.cache_hook = CacheHook(None)
self.task_dict = {}
lmms.__init__(self)
# Do not use kwargs for now
assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

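For reference, a rough sketch of the state that the deleted hand-rolled initialization used to set up and that lmms.__init__(self) is now relied on to provide; the actual base class in lmms_eval may initialize more than this.

# Illustrative restatement of the fields the removed lines initialized by hand;
# the real lmms base class may set additional state.
from lmms_eval.api.model import CacheHook


class _LmmsInitSketch:
    def __init__(self):
        self._rank = 0                     # single-process default rank
        self._world_size = 1               # single-process default world size
        self.cache_hook = CacheHook(None)  # no-op request cache hook
        self.task_dict = {}                # populated later by the eval harness
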
70 changes: 69 additions & 1 deletion llmc/models/qwen2vl.py
@@ -1,8 +1,13 @@
import inspect
from typing import Optional, Union

import torch
import torch.nn as nn
from accelerate import Accelerator, DistributedType
from lmms_eval.api.model import lmms
from lmms_eval.models.qwen2_vl import Qwen2_VL
from loguru import logger
from transformers import AutoConfig, AutoProcessor
from transformers import AutoConfig, AutoProcessor, AutoTokenizer

try:
from transformers import Qwen2VLForConditionalGeneration
@@ -31,6 +36,7 @@ def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

def build_model(self):
self.eval_name = 'Qwen2VLEval'
self.vlm_model_config = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
@@ -203,3 +209,65 @@ def forward(self, *args, **kwargs):
first_block_input['kwargs'].append(kwargs)
raise ValueError
return Catcher


@MODEL_REGISTRY
class Qwen2VLEval(Qwen2_VL):
def __init__(
self,
llmc_model,
pretrained: str = 'Qwen/Qwen2-VL-7B-Instruct',
device: Optional[str] = 'cuda',
device_map: Optional[str] = 'cuda',
batch_size: Optional[Union[int, str]] = 1,
use_cache=True,
use_flash_attention_2: Optional[bool] = False,
max_pixels: int = 12845056,
min_pixels: int = 3136,
max_num_frames: int = 32,
**kwargs,
) -> None:
lmms.__init__(self)
# Do not use kwargs for now
assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

accelerator = Accelerator()
if accelerator.num_processes > 1:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'
elif accelerator.num_processes == 1 and device_map == 'auto':
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'

self._model = llmc_model.eval().cuda()
self.processor = AutoProcessor.from_pretrained(pretrained,
max_pixels=max_pixels, min_pixels=min_pixels)
self.max_pixels = max_pixels
self.min_pixels = min_pixels
self.max_num_frames = max_num_frames
self._tokenizer = AutoTokenizer.from_pretrained(pretrained)

self._config = self.model.config
self.batch_size_per_gpu = int(batch_size)
self.use_cache = use_cache

if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self._rank = 0
            self._world_size = 1
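
Both new wrappers repeat the same device-selection branching. The restatement below (not code from the repository) summarizes the effective rule:

import torch
from accelerate import Accelerator


def resolve_device(accelerator: Accelerator, device: str, device_map: str):
    # Summary of the branching in InternVL2Eval and Qwen2VLEval: every rank of a
    # multi-process run is pinned to its local GPU; a single process keeps the
    # user-supplied device only when device_map == 'auto', and otherwise also
    # falls back to the local process index.
    if accelerator.num_processes > 1 or device_map != 'auto':
        target = f'cuda:{accelerator.local_process_index}'
        return torch.device(target), target
    return torch.device(device), device_map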
