support for deepseek vl2 #401

Merged: 5 commits, Jan 7, 2025
22 changes: 15 additions & 7 deletions auto_round/auto_quantizer.py
@@ -416,16 +416,24 @@ def convert_model(self, model: nn.Module):
         data_type = quantization_config.data_type if hasattr(quantization_config,
                                                               "data_type") else "int"  # pragma: no cover
         sym = quantization_config.sym
-        to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config,
-                                                                                   "to_quant_block_names") else None
-        if to_quant_block_names is None:  # TODO check compatibility
-            all_blocks = get_block_names(model)
-        else:
-            all_blocks = get_multimodal_block_names(model, quant_vision=True)
-        quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
+        quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config,
+                                                                           "quant_block_list") else None
+        if quant_block_list is None:
+            to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config,
+                                                                                       "to_quant_block_names") else None
+            if to_quant_block_names is not None:
+                if isinstance(to_quant_block_names, (list, tuple)):
+                    quant_block_list = to_quant_block_names
+                else:
+                    quant_block_list = []
+                    for block in to_quant_block_names.split(','):
+                        quant_block_list.append([f'{block}.{i}' for i in range(len(get_module(model, block)))])
+            else:
+                all_blocks = get_block_names(model)
+                quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)

         layer_names = get_layer_names_in_block(model, quant_block_list=quant_block_list)

         extra_config = {}
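For context, the new branch above expands a comma-separated to_quant_block_names string into per-block lists of indexed layer names (in the real code, len(get_module(model, block)) supplies the count per block). A minimal standalone sketch of that expansion, with made-up block names and counts:

def expand_block_names(to_quant_block_names, block_lengths):
    # block_lengths stands in for len(get_module(model, block)); the values below are hypothetical.
    quant_block_list = []
    for block in to_quant_block_names.split(','):
        quant_block_list.append([f'{block}.{i}' for i in range(block_lengths[block])])
    return quant_block_list

print(expand_block_names("language.model.layers,projector.layers",
                         {"language.model.layers": 2, "projector.layers": 1}))
# [['language.model.layers.0', 'language.model.layers.1'], ['projector.layers.0']]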
1 change: 1 addition & 0 deletions auto_round/mllm/template.py
@@ -166,6 +166,7 @@ def get_template(template_or_path: str, model=None, tokenizer=None, processor=No
     else:
         logger.warning(f"Unable to recognize {template_or_path}, using default template instead.")
         template = TEMPLATES["default"]
+    template.model_type = template_or_path

     template.processor.post_init(model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)

57 changes: 33 additions & 24 deletions auto_round/script/mllm.py
@@ -281,32 +281,41 @@ def tune(args):

     # load_model
     processor, image_processor = None, None
-    config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-    if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
-        from llava.model.builder import load_pretrained_model  # pylint: disable=E0401
-        tokenizer, model, image_processor, _ = load_pretrained_model(
-            model_name, model_base=None, model_name=model_name,
-            torch_dtype=torch_dtype)
-        model_type = "llava"
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        model_type = config.model_type
-        if "llava" in model_type:
-            from transformers import LlavaForConditionalGeneration
-            cls = LlavaForConditionalGeneration
-        elif "qwen2_vl" in model_type:
-            from transformers import Qwen2VLForConditionalGeneration
-            cls = Qwen2VLForConditionalGeneration
-        elif "mllama" in model_type:
-            from transformers import MllamaForConditionalGeneration
-            cls = MllamaForConditionalGeneration
-        else:
-            cls = AutoModelForCausalLM
-
-        model = cls.from_pretrained(
+    if "deepseek-vl2" in model_name.lower():
+        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM  # pylint: disable=E0401
+        processor = DeepseekVLV2Processor.from_pretrained(model_name)
+        tokenizer = processor.tokenizer
+        model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
             model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
             device_map="auto" if use_auto_mapping else None)
+        model_type = "deepseek_vl_v2"
+    else:
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+        if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
+            from llava.model.builder import load_pretrained_model  # pylint: disable=E0401
+            tokenizer, model, image_processor, _ = load_pretrained_model(
+                model_name, model_base=None, model_name=model_name,
+                torch_dtype=torch_dtype)
+            model_type = "llava"
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+            processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+            model_type = config.model_type
+            if "llava" in model_type:
+                from transformers import LlavaForConditionalGeneration
+                cls = LlavaForConditionalGeneration
+            elif "qwen2_vl" in model_type:
+                from transformers import Qwen2VLForConditionalGeneration
+                cls = Qwen2VLForConditionalGeneration
+            elif "mllama" in model_type:
+                from transformers import MllamaForConditionalGeneration
+                cls = MllamaForConditionalGeneration
+            else:
+                cls = AutoModelForCausalLM
+
+            model = cls.from_pretrained(
+                model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
+                device_map="auto" if use_auto_mapping else None)
     if "cogvlm2" in model_name:
         model.config.model_type = "cogvlm2"
 
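The new loading path is selected purely by the model path containing "deepseek-vl2" and is exercised end-to-end by the CUDA test added below. For orientation, that test's command can be reproduced roughly as follows; the local model path, output directory, and hyperparameters are the test's examples, not requirements:

import subprocess, sys

# Mirrors the invocation in test_deepseek_vl2 below; adjust paths to your environment.
cmd = [
    sys.executable, "-m", "auto_round", "--mllm",
    "--model", "/models/deepseek-vl2-tiny",
    "--iter", "3", "--nsamples", "10", "--bs", "4", "--group_size", "32",
    "--fp_layers", "language.model.layers.4,language.model.layers.6",
    "--output_dir", "./ut_saved", "--device", "auto",
]
subprocess.run(cmd, check=True)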
14 changes: 14 additions & 0 deletions auto_round/special_model_handler.py
@@ -28,6 +28,20 @@
"idefics3"
]

def _get_deepseek_vl2_multimodal_block(model, quant_vision=False):
model.forward = model.language.forward
block_names = []
if quant_vision:
block_names.append([f"vision.blocks.{i}" for i in range(len(model.vision.blocks))])
Contributor review comment on the line above: better to support an argument for passing the vision block name; this could be done later.

block_names.append([f"projector.layers.{i}" for i in range(len(model.projector.layers))])
block_names.append([f"language.model.layers.{i}" for i in range(len(model.language.model.layers))])
return block_names

SPECIAL_MULTIMODAL_BLOCK = {
"deepseek_vl_v2": _get_deepseek_vl2_multimodal_block
}


def to_device(input, device=torch.device("cpu")):
"""Moves input data to the specified device.

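For orientation, with a hypothetical tiny configuration (say 2 vision blocks, 1 projector layer, and 3 language layers; real checkpoints differ), _get_deepseek_vl2_multimodal_block would return nested name lists like the sketch below. It also rebinds model.forward to the language sub-model's forward so the tuning loop drives the text decoder directly.

# Hypothetical return value for a toy config (2 vision blocks, 1 projector layer, 3 language layers).
expected_with_vision = [
    ["vision.blocks.0", "vision.blocks.1"],
    ["projector.layers.0"],
    ["language.model.layers.0", "language.model.layers.1", "language.model.layers.2"],
]
expected_language_only = expected_with_vision[1:]  # quant_vision=False skips the vision blocks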
4 changes: 3 additions & 1 deletion auto_round/utils.py
@@ -29,7 +29,7 @@
 from functools import lru_cache
 from packaging import version
 import gc
-from .special_model_handler import shareable_keywords
+from .special_model_handler import shareable_keywords, SPECIAL_MULTIMODAL_BLOCK
 
 
 @lru_cache(None)
@@ -402,6 +402,8 @@ def get_multimodal_block_names(model, quant_vision=False):
     Returns:
         block_names: A list whose elements are list of block's layer names
     """
+    if hasattr(model, "config") and model.config.model_type in SPECIAL_MULTIMODAL_BLOCK.keys():
+        return SPECIAL_MULTIMODAL_BLOCK.get(model.config.model_type)(model, quant_vision=quant_vision)
     block_names = []
     target_modules = []
     vison_blocks_tuple = ("vision", "visual",)
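The lookup added here is a plain dict dispatch: a registered model_type wins over the generic "vision"/"visual" keyword scan that follows. A self-contained sketch of the same pattern, with a fake model and a trivial handler (both purely illustrative):

from types import SimpleNamespace

# Stand-in for SPECIAL_MULTIMODAL_BLOCK; the handler here is a trivial fake.
special_handlers = {
    "deepseek_vl_v2": lambda model, quant_vision=False: [["language.model.layers.0"]],
}

def pick_blocks(model, quant_vision=False):
    # Same guard as the new code: dispatch only when the model has a config with a registered model_type.
    if hasattr(model, "config") and model.config.model_type in special_handlers:
        return special_handlers[model.config.model_type](model, quant_vision=quant_vision)
    return []  # the real function falls back to scanning vision/visual submodules

fake = SimpleNamespace(config=SimpleNamespace(model_type="deepseek_vl_v2"))
print(pick_blocks(fake, quant_vision=True))  # [['language.model.layers.0']]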
66 changes: 65 additions & 1 deletion test_cuda/test_support_vlms.py
@@ -13,7 +13,7 @@
 class TestSupportVLMS(unittest.TestCase):
     @classmethod
     def setUpClass(self):
-        self.save_dir = os.path.join(os.path.dirname(__file__), "./ut_saved")
+        self.save_dir = os.path.join(os.path.dirname(__file__), "ut_saved")
         self.python_path = sys.executable
         self.device = 0
 
@@ -333,6 +333,70 @@ def test_72b(self):
         )
         self.assertFalse(res > 0 or res == -1, msg="qwen2-72b tuning fail")
         shutil.rmtree(self.save_dir, ignore_errors=True)
 
+    def test_deepseek_vl2(self):
+        model_path = "/models/deepseek-vl2-tiny"
+        res = os.system(
+            f"cd .. && {self.python_path} -m auto_round --mllm "
+            f"--model {model_path} --iter 3 --nsamples 10 --bs 4 --output_dir {self.save_dir} --device auto --group_size 32 "
+            f"--fp_layers language.model.layers.4,language.model.layers.6"
+        )
+        self.assertFalse(res > 0 or res == -1, msg="deepseek vl2 tuning fail")
+
+        quantized_model_path = os.path.join(self.save_dir, "deepseek-vl2-tiny-w4g32-auto_round")
+        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
+        from transformers import AutoModelForCausalLM
+        vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(quantized_model_path)
+        tokenizer = vl_chat_processor.tokenizer
+
+        vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
+            quantized_model_path,
+            trust_remote_code=True,
+            device_map=f"cuda:{self.device}",
+            torch_dtype="auto",
+        )
+        vl_gpt = vl_gpt.eval()
+
+        image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
+        content = "Describe this image."
+
+        ## single image conversation example
+        conversation = [
+            {
+                "role": "<|User|>",
+                "content": content,
+            },
+            {"role": "<|Assistant|>", "content": ""},
+        ]
+
+        # load images and prepare for inputs
+        pil_images = Image.open(requests.get(image_url, stream=True).raw)
+        prepare_inputs = vl_chat_processor(
+            conversations=conversation,
+            images=[pil_images],
+            force_batchify=True,
+            system_prompt=""
+        )
+        prepare_inputs = prepare_inputs.to(vl_gpt.device)
+
+        # run image encoder to get the image embeddings
+        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+        # run the model to get the response
+        outputs = vl_gpt.language.generate(
+            input_ids=prepare_inputs["input_ids"],
+            inputs_embeds=inputs_embeds,
+            attention_mask=prepare_inputs.attention_mask,
+            pad_token_id=tokenizer.eos_token_id,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            max_new_tokens=512,
+            do_sample=False,
+            use_cache=True
+        )
+
+        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+        print(f"{prepare_inputs['sft_format'][0]}", answer)
+
 if __name__ == "__main__":
     unittest.main()