
@afierka-intel afierka-intel commented Apr 16, 2025

Original PR #897

bmyrcha and others added 12 commits April 8, 2025 09:51
Switched execution of versioned branches to _next and added log redirection to a file.
Adjusted the method of extracting the Synapse build ID for release branches.
This PR implements HPU support for pipeline parallelism. Accuracy was tested
and matches TP accuracy on:
- Llama3.1-70b-Instruct
- Llama3.2-3b-Instruct
- Mixtral-8x7b

To serve with PP:
`VLLM_DECODE_BS_BUCKET_MIN=384 VLLM_DECODE_BLOCK_BUCKET_MAX=896 vllm
serve /mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-70B-Instruct/
--tensor-parallel-size 1 --pipeline-parallel-size 4 --max-num-seqs 384
--disable-log-requests --dtype bfloat16 --gpu-memory-util 0.9
--disable-log-stats --num_scheduler_steps 1 --max-num-batched-tokens
2048 --max-model-len 256 --block-size 128`
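
Once the server is up, it can be exercised through vLLM's OpenAI-compatible completions endpoint. A minimal client sketch, assuming the default host and port (`localhost:8000`) and the model path from the command above as the served model name:

```python
# Minimal client sketch for the PP server started above.
# Assumptions: default `vllm serve` host/port (localhost:8000) and the model
# path from the serve command as the served model name.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-70B-Instruct/",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0.0,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```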

Known issues:
* Since max_num_seqs acts as the microbatch size for a single virtual_engine
under pipeline parallelism, a larger batch size hits a very specific corner
case and triggers a flat_pa error. Set the batch size to roughly the batch
size you would use with TP, divided by pp_size (see the sketch after this list).
* Delayed sampling is not yet compatible with pipeline parallelism.
* The virtual_engine ID is passed to HPUGraph, which results in pp_size times
as many graphs.
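
A rough illustration of the batch-size rule from the first known issue (a starting point for tuning, not an exact formula):

```python
# Illustration of the rule of thumb above: start from the batch size that
# worked in a TP-only setup, divided by pp_size, then tune per virtual_engine.
tp_max_num_seqs = 384   # max_num_seqs that worked with tensor parallelism only
pp_size = 4             # --pipeline-parallel-size
starting_max_num_seqs = tp_max_num_seqs // pp_size  # ~96 per virtual_engine
print(starting_max_num_seqs)
```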

Signed-off-by: jmaksymczuk <jmaksymczuk@habana.ai>
Co-authored-by: Rafal Litka <rlitka@habana.ai>
Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
Cherry-pick of #1023

Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
Cherry-pick of #921

Co-authored-by: Konrad Zawora <kzawora@habana.ai>
Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
Co-authored-by: Iryna Boiko <iboiko@habana.ai>
Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
migrated from a PR to habana_main:
#1014

For best performance, it is recommended to run this PR with INC:
[[SW-223553] [VLLM] Merge deepseek changes into habana_main - Habana
Labs](https://jira.habana-labs.com/browse/SW-223553)

**Test accuracy on G3**:
```bash
huggingface-cli download Yi30/inc-woq-default-pile-one-cache-408  --local-dir ./scripts/nc_workspace_measure_kvache

cat inc_quant_with_fp8kv_config.json
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./inc-woq-default-pile-one-cache-408-for-fp8-mla/inc_measure_output"
}


QUANT_CONFIG=inc_quant_with_fp8kv_config.json \
PT_HPU_LAZY_MODE=1 \
VLLM_SKIP_WARMUP=true \
PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
VLLM_MLA_DISABLE_REQUANTIZATION=1 \
lm_eval --model vllm \
  --model_args "pretrained=/mnt/weka/data/pytorch/DeepSeek-R1/,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,kv_cache_dtype=fp8_inc" \
  --tasks gsm8k --num_fewshot "5" --limit "256" \
  --batch_size "8"
```
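
The blocklist `names` entries in the config above appear to be regular expressions matched against module names (the escaped `mlp\.gate\b` pattern suggests this). A minimal sketch, assuming regex-search semantics and using hypothetical module names, of which modules such patterns would keep out of quantization:

```python
import re

# Assumption: INC matches blocklist "names" as regular expressions searched
# against module names; the module names below are hypothetical examples.
blocklist_names = [r"lm_head", r"mlp\.gate\b", r"block2batch_matmul"]

modules = [
    "model.layers.3.mlp.gate",        # MoE router gate -> excluded from quantization
    "model.layers.3.mlp.gate_proj",   # \b prevents a match here, so it stays quantizable
    "lm_head",                        # excluded from quantization
    "model.layers.3.self_attn.q_proj" # not on the blocklist -> quantized
]

for name in modules:
    blocked = any(re.search(pat, name) for pat in blocklist_names)
    print(f"{name}: {'blocked (not quantized)' if blocked else 'quantized'}")
```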

**Test accuracy on G2**:
**Convert the original DeepSeek-R1** checkpoint using
[convert_for_g2.py](https://github.com/yangulei/vllm-fork/blob/deepseek_r1_g2/scripts/convert_for_g2.py)
(this step will be removed once INC is updated).

```bash

huggingface-cli download Yi30/inc-woq-default-pile-one-cache-412-g2  --local-dir ./scripts/nc_workspace_measure_kvache

cat inc_quant_with_fp8kv_config.json
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output"
}
```


vllm (pretrained=/mnt/weka/data/pytorch/DeepSeek-R1/,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,kv_cache_dtype=fp8_inc), gen_kwargs: (None), limit: 256.0, num_fewshot: 5, batch_size: 128

|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9492|± |0.0137|
| | |strict-match | 5|exact_match|↑ |0.9453|± |0.0142|


----------
Requires vllm-hpu-extension from:
https://github.com/HabanaAI/vllm-hpu-extension/tree/dev/chendi/deepseek_r1
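
A quick import check can confirm the extension is actually picked up at runtime; a minimal sketch, assuming the package installs under the `vllm_hpu_extension` name:

```python
# Sanity check that vllm-hpu-extension is importable before launching vLLM.
# Assumption: the package name is `vllm_hpu_extension`; install the branch
# above first (for example via pip's `git+<repo-url>@dev/chendi/deepseek_r1` syntax).
import importlib.util

spec = importlib.util.find_spec("vllm_hpu_extension")
if spec is None:
    raise SystemExit("vllm-hpu-extension is not installed")
print(f"vllm-hpu-extension found at {spec.origin}")
```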

Status:

* Runnable with DeepSeek-R1.
* Accuracy check with block FP8 weights => garbage output.
* Accuracy check with BF16 weights => looks good.

Test script:
```python
from vllm import LLM, SamplingParams
import os

os.environ['VLLM_SKIP_WARMUP'] = 'true'
os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES']='true'
os.environ['PT_HPU_WEIGHT_SHARING']='0'
#os.environ['HABANA_LOGS']="vllm_inc_debug"
#os.environ["LOG_LEVEL_ALL"]="3"
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'
#os.environ["QUANT_CONFIG"] = "inc_quant_with_fp8kv_config.json"
#os.environ["LOGLEVEL"] = "DEBUG"

 
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
 
if __name__ == "__main__":
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16, ignore_eos=True)
 
    # Create an LLM.
    model_path = "/data/models/DeepSeek-R1"
 
    llm = LLM(model=model_path,
            trust_remote_code=True,
            enforce_eager=True,
            dtype="bfloat16",
            use_v2_block_manager=True,
            max_model_len=1024,
            max_num_seqs=1,
            tensor_parallel_size=8,
            distributed_executor_backend='mp',
            gpu_memory_utilization=0.8,
            #kv_cache_dtype="fp8_inc",
            seed=2024)
 
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    if os.environ.get("QUANT_CONFIG", None) is not None:
        llm.llm_engine.model_executor.shutdown()
```

---------

Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Signed-off-by: kwisniewski98 <kwisniewski@habana.ai>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: kwisniewski98 <kwisniewski@habana.ai>
Same PR as #996, but targeting the v1.21.0_next branch.

The make_attn_bias in hpu_model_runner does not cover the non-causal mask
needed by embedding models, and the vertical mask-off is also not applied
when merged prefill is enabled.
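
For context, a causal bias masks out future positions, while an embedding (non-causal) model needs a bias that only masks padding. A generic PyTorch sketch of the difference, independent of the actual `make_attn_bias` implementation in `hpu_model_runner`:

```python
# Illustration of causal vs. non-causal attention bias; this is a generic
# PyTorch sketch, not the hpu_model_runner code.
import torch

seq_len, pad_len = 4, 2                      # 4 real tokens, 2 padding tokens
total = seq_len + pad_len
valid = torch.tensor([True] * seq_len + [False] * pad_len)

# Causal bias: token i may only attend to tokens j <= i, and never to padding.
causal = torch.triu(torch.full((total, total), float("-inf")), diagonal=1)
causal[:, ~valid] = float("-inf")

# Non-causal bias for an embedding model: all valid tokens attend to each
# other; only padding positions are masked out.
non_causal = torch.zeros(total, total)
non_causal[:, ~valid] = float("-inf")

print(causal)
print(non_causal)
```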
Synchronize 12 vLLM flags to non-driver workers in the Ray executor.

Fixes the "not warmed-up" bucket issue in cross-node vLLM inference.

Root cause: the 12 vLLM flags were not synchronized to the non-driver workers
within the Ray cluster (see the screenshot below).


![image](https://github.com/user-attachments/assets/fb51cefc-b23a-434d-a641-493592d896a6)
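
One generic way to make environment flags visible to every Ray worker, driver and non-driver alike, is Ray's `runtime_env`. A minimal sketch of that mechanism (not the exact fix implemented in this PR, and the flag names listed are only examples):

```python
# Generic illustration of propagating env vars to all Ray workers via runtime_env.
# This shows the Ray mechanism only; it is not the change made in this PR, and
# the flag names below are just examples from this thread.
import os
import ray

vllm_flags = {
    name: os.environ[name]
    for name in ("PT_HPU_LAZY_MODE", "VLLM_SKIP_WARMUP", "PT_HPU_ENABLE_LAZY_COLLECTIVES")
    if name in os.environ
}

ray.init(runtime_env={"env_vars": vllm_flags})

@ray.remote
def show_flags():
    # Runs on a (possibly remote, non-driver) worker; the flags are visible here.
    return {name: os.environ.get(name) for name in vllm_flags}

print(ray.get(show_flags.remote()))
```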

---------

Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>