
Conversation

@jmaksymczuk jmaksymczuk commented Apr 2, 2025

This PR implements HPU support for pipeline parallelism.
Accuracy was tested and matches TP accuracy on:

  • Llama3.1-70b-Instruct
  • Llama3.2-3b-Instruct
  • Mixtral-8x7b

To serve with PP:
`VLLM_DECODE_BS_BUCKET_MIN=384 VLLM_DECODE_BLOCK_BUCKET_MAX=896 vllm serve /mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-70B-Instruct/ --tensor-parallel-size 1 --pipeline-parallel-size 4 --max-num-seqs 384 --disable-log-requests --dtype bfloat16 --gpu-memory-util 0.9 --disable-log-stats --num_scheduler_steps 1 --max-num-batched-tokens 2048 --max-model-len 256 --block-size 128`
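
Once the server is up, requests go through the usual OpenAI-compatible endpoint. The snippet below is a minimal client sketch, assuming the default host/port (localhost:8000) and no `--served-model-name` override, so the `model` field is the same path passed to `vllm serve`:

```python
# Minimal sketch of a completion request against the OpenAI-compatible server
# started by the `vllm serve` command above.
# Assumptions: default port 8000, no --served-model-name override.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-70B-Instruct/",
        "prompt": "Pipeline parallelism splits a model across",
        "max_tokens": 64,          # keep prompt + output within --max-model-len 256
        "temperature": 0.0,
    },
)
print(resp.json()["choices"][0]["text"])
```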

Known issues:

  • because in pipeline parallelism max_num_seqs acts as the microbatch size for a single virtual_engine, larger batch sizes hit a very specific corner case and trigger a flat_pa error; as a workaround, set the batch size to roughly the batch size you would use with TP divided by pp_size (see the sketch after this list)
  • delayed sampling is not yet compatible with pipeline parallelism
  • the virtual_engine ID is passed to HPUGraph, which results in pp_size times the number of graphs
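
The first known issue boils down to a simple rule of thumb; the sketch below only illustrates the arithmetic, and the TP batch size used here is a hypothetical example value, not a measured recommendation:

```python
# Illustrative sketch of the batch-size workaround above (example numbers only):
# with pipeline parallelism, max_num_seqs is the per-virtual_engine microbatch,
# so start from the batch size you would use with TP and divide by pp_size.
tp_batch_size = 1536   # hypothetical batch size you would run with TP only
pp_size = 4            # --pipeline-parallel-size
max_num_seqs = tp_batch_size // pp_size   # -> 384, matching --max-num-seqs above
print(max_num_seqs)
```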

@kwisniewski98 kwisniewski98 left a comment

LGTM

@michalkuligowski

/run-gaudi-tests

@michalkuligowski

/run-gaudi-tests

@michalkuligowski

/run-gaudi-tests

@jmaksymczuk jmaksymczuk merged commit 2cf9580 into habana_main Apr 9, 2025
44 checks passed
@jmaksymczuk jmaksymczuk deleted the hpu_pp_new branch April 9, 2025 08:41
jmaksymczuk pushed a commit that referenced this pull request Apr 9, 2025
Signed-off-by: jmaksymczuk <jmaksymczuk@habana.ai>
Co-authored-by: Rafal Litka <rlitka@habana.ai>
michalkuligowski added a commit that referenced this pull request Apr 10, 2025
Signed-off-by: jmaksymczuk <jmaksymczuk@habana.ai>
Co-authored-by: Rafal Litka <rlitka@habana.ai>
Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
czhu15 pushed a commit that referenced this pull request May 15, 2025
- Enables the PP solution with full support for DeepSeek R1 execution with PP>1.
- Requires 1.21.0 or newer. Does not support 1.20.1 or older.
- Implementation mirrors #1000 as closely as possible while ensuring DeepSeek R1 functions fully.
- Adds a benchmark script for sweeping various configs automatically. This can be removed if you feel it shouldn't be merged to the deepseek_r1 branch.

Additional validation is being done by yabai.hu@intel.com.

@czhu15 youlei.yang@intel.com, please help start the review in the meantime.

Signed-off-by: Voas, Tanner <tanner.voas@intel.com>
Co-authored-by: Hu, Yabai <yabai.hu@intel.com>
Co-authored-by: Ji, Kunshang <kunshang.ji@intel.com>
Co-authored-by: Sheng, Yi <yi.sheng@intel.com>
Co-authored-by: Chen, Xinyu <xinyu1.chen@intel.com>