
[Bug]: InternVL2-26B tensor_parallel_size=4, AssertionError: 25 is not divisible by 4 #8097

Closed
1 task done
SovereignRemedy opened this issue Sep 3, 2024 · 3 comments · Fixed by #8061
Labels
bug Something isn't working

Comments

@SovereignRemedy

Your current environment

Python 3.8
8x A10 GPUs
Model: InternVL2-26B
vLLM branch: main
torch 2.4.0
torchvision 0.19.0

🐛 Describe the bug

Refs:
#8055 (comment)
#7996

The referenced changes fix some of the InternVL2 inference problems, but inference still fails when running with tensor parallelism across multiple GPUs.

(VllmWorkerProcess pid=29708) ERROR 09-03 12:08:10 multiproc_worker_utils.py:226]
[rank0]: Traceback (most recent call last):
[rank0]:   File "offline_inference_vision_language.py", line 240, in <module>
[rank0]:     main(args)
[rank0]:   File "offline_inference_vision_language.py", line 190, in main
[rank0]:     llm, prompt, stop_token_ids = model_example_map[model](question)
[rank0]:   File "offline_inference_vision_language.py", line 135, in run_internvl
[rank0]:     llm = LLM(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/entrypoints/llm.py", line 177, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/engine/llm_engine.py", line 541, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/engine/llm_engine.py", line 302, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/executor/distributed_gpu_executor.py", line 26, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/executor/executor_base.py", line 47, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/executor/multiproc_gpu_executor.py", line 125, in _init_executor
[rank0]:     self._run_workers("load_model",
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/executor/multiproc_gpu_executor.py", line 199, in _run_workers
[rank0]:     driver_worker_output = driver_worker_method(*args, **kwargs)
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/worker/worker.py", line 182, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/worker/model_runner.py", line 915, in load_model
[rank0]:     self.model = get_model(model_config=self.model_config,
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/model_loader/loader.py", line 341, in load_model
[rank0]:     model = _initialize_model(model_config, self.load_config,
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/model_loader/loader.py", line 170, in _initialize_model
[rank0]:     return build_model(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/model_loader/loader.py", line 155, in build_model
[rank0]:     return model_class(config=hf_config,
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/models/internvl.py", line 328, in __init__
[rank0]:     self.vision_model = InternVisionModel(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/models/intern_vit.py", line 252, in __init__
[rank0]:     self.encoder = InternVisionEncoder(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/models/intern_vit.py", line 228, in __init__
[rank0]:     self.layers = nn.ModuleList([
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/models/intern_vit.py", line 229, in <listcomp>
[rank0]:     InternVisionEncoderLayer(config=config, quant_config=quant_config)
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/models/intern_vit.py", line 190, in __init__
[rank0]:     self.attn = InternAttention(config, quant_config=quant_config)
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/models/intern_vit.py", line 104, in __init__
[rank0]:     self.qkv = QKVParallelLinear(
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/model_executor/layers/linear.py", line 659, in __init__
[rank0]:     self.num_heads = divide(self.total_num_heads, tp_size)
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/distributed/utils.py", line 24, in divide
[rank0]:     ensure_divisibility(numerator, denominator)
[rank0]:   File "/home/tdj/intervl/temp_vllm/vllm/vllm/distributed/utils.py", line 17, in ensure_divisibility
[rank0]:     assert numerator % denominator == 0, "{} is not divisible by {}".format(
[rank0]: AssertionError: 25 is not divisible by 4
ERROR 09-03 12:08:10 multiproc_worker_utils.py:120] Worker VllmWorkerProcess pid 29708 died, exit code: -15
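
For reference, the assertion comes from vLLM's tensor-parallel divide helper shown in the last frames of the traceback. A minimal standalone reproduction of that check, assuming the vision tower's 25 attention heads are being sharded across tp_size=4:

def ensure_divisibility(numerator: int, denominator: int) -> None:
    # Mirrors the check in vllm/distributed/utils.py from the traceback:
    # every TP rank must receive a whole number of attention heads.
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)

def divide(numerator: int, denominator: int) -> int:
    ensure_divisibility(numerator, denominator)
    return numerator // denominator

# InternVL2-26B's vision tower (InternViT-6B) has 25 attention heads,
# so QKVParallelLinear cannot split them evenly over 4 GPUs:
divide(25, 4)  # AssertionError: 25 is not divisible by 4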

My code:

from transformers import AutoTokenizer
from vllm import LLM


# InternVL
def run_internvl(question):
    model_name = "/home/tdj/model/InternVL2-26B"

    llm = LLM(
        model=model_name,
#        dtype = "half",
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=4,
        max_num_batched_tokens=8192,
        max_model_len=4096,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids
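
As a hypothetical pre-flight check (not something vLLM provides; it assumes the InternVL2 HF config exposes the vision tower under vision_config), one could verify that the requested tensor_parallel_size divides the vision head count before constructing the engine:

from transformers import AutoConfig

def check_vision_tp(model_name: str, tp_size: int) -> None:
    # Hypothetical helper: reads the vision tower's head count from the HF config.
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    num_heads = cfg.vision_config.num_attention_heads  # 25 for InternVL2-26B
    if num_heads % tp_size != 0:
        raise ValueError(
            f"vision encoder has {num_heads} heads, "
            f"not divisible by tensor_parallel_size={tp_size}")

check_vision_tp("/home/tdj/model/InternVL2-26B", tp_size=4)  # fails fast, before engine startup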

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@SovereignRemedy added the bug label on Sep 3, 2024
@DarkLight1337
Member

Should be solved by #8061 which automatically disables TP in this case.
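
For illustration only (a sketch of the idea, not the actual diff in #8061): when the head count does not divide evenly by the TP size, the vision attention can fall back to an unsharded QKV projection replicated on every rank, e.g.:

import torch.nn as nn

tp_size = 4
num_heads = 25                        # InternViT-6B vision tower
head_dim = 128                        # illustrative value
embed_dim = num_heads * head_dim

if num_heads % tp_size == 0:
    heads_per_rank = num_heads // tp_size   # shard heads across TP ranks
else:
    heads_per_rank = num_heads              # fall back: keep every head on each rank

qkv = nn.Linear(embed_dim, 3 * heads_per_rank * head_dim)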

@DarkLight1337
Member

See #8061 (comment)

@SovereignRemedy
Author

SovereignRemedy commented Sep 3, 2024

See #8061 (comment)

Sorry, I will keep following that PR.
