
Commit bfe6681

nrghosh authored and SheldonTsen committed
[docs][serve][llm] examples and doc for cross-node TP/PP in Serve (ray-project#57715)
Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
Signed-off-by: Nikhil G <nrghosh@users.noreply.github.com>
1 parent 87c5b05 commit bfe6681

File tree

5 files changed: 308 additions & 0 deletions

doc/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ py_test_run_all_subdirectory(
         "source/serve/doc_code/stable_diffusion.py",
         "source/serve/doc_code/object_detection.py",
         "source/serve/doc_code/vllm_example.py",
+        "source/serve/doc_code/cross_node_parallelism_example.py",
         "source/serve/doc_code/llm/llm_yaml_config_example.py",
         "source/serve/doc_code/llm/qwen_example.py",
     ],
doc/source/serve/doc_code/cross_node_parallelism_example.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
# flake8: noqa
"""
Cross-node parallelism examples for Ray Serve LLM.

TP / PP / custom placement group strategies
for multi-node LLM deployments.
"""

# __cross_node_tp_example_start__
import vllm
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with tensor parallelism across 2 GPUs
# Tensor parallelism splits model weights across GPUs
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        max_model_len=8192,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_tp_example_end__

# __cross_node_pp_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with pipeline parallelism across 2 GPUs
# Pipeline parallelism splits model layers across GPUs
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        pipeline_parallel_size=2,
        max_model_len=8192,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_pp_example_end__

# __cross_node_tp_pp_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with both tensor and pipeline parallelism
# This example uses 4 GPUs total (2 TP * 2 PP)
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        max_model_len=8192,
        enable_chunked_prefill=True,
        max_num_batched_tokens=4096,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_tp_pp_example_end__

# __custom_placement_group_pack_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using PACK strategy
# PACK tries to place workers on as few nodes as possible for locality
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 2,
        strategy="PACK",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_pack_example_end__

# __custom_placement_group_spread_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using SPREAD strategy
# SPREAD distributes workers across nodes for fault tolerance
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=4,
        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 4,
        strategy="SPREAD",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_spread_example_end__

# __custom_placement_group_strict_pack_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using STRICT_PACK strategy
# STRICT_PACK ensures all workers are placed on the same node
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="A100",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 2,
        strategy="STRICT_PACK",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_strict_pack_example_end__

doc/source/serve/llm/quick-start.md

Lines changed: 8 additions & 0 deletions
@@ -270,3 +270,11 @@ serve run config.yaml
 
 For monitoring and observability, see {doc}`Observability <user-guides/observability>`.
 
+## Advanced usage patterns
+
+For each usage pattern, Ray Serve LLM provides a server and client code snippet.
+
+### Cross-node parallelism
+
+Ray Serve LLM supports cross-node tensor parallelism (TP) and pipeline parallelism (PP), allowing you to distribute model inference across multiple GPUs and nodes. See {doc}`Cross-node parallelism <user-guides/cross-node-parallelism>` for a comprehensive guide to configuring and deploying these models.
+
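The deployments in these examples expose an OpenAI-compatible HTTP API. The following is a minimal client sketch, assuming the default Serve address `http://localhost:8000`, the `/v1` route prefix, and the `llama-3.1-8b` model ID from the server examples; the API key is only a placeholder for a local deployment.

```python
# Minimal client sketch: query the OpenAI-compatible endpoint that
# build_openai_app exposes. Assumes the default Serve HTTP address
# http://localhost:8000, the /v1 route prefix, and the model_id
# "llama-3.1-8b" configured in the server examples.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",  # assumed default Serve address
    api_key="FAKE_KEY",                   # placeholder for a local deployment
)

response = client.chat.completions.create(
    model="llama-3.1-8b",  # matches model_id in the LLMConfig examples
    messages=[{"role": "user", "content": "What is tensor parallelism?"}],
)
print(response.choices[0].message.content)
```

The same client code applies whether the replica uses TP, PP, or both; the parallelism settings change how engine workers are placed, not the serving API.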
doc/source/serve/llm/user-guides/cross-node-parallelism.md

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
(cross-node-parallelism)=
# Cross-node parallelism

Ray Serve LLM supports cross-node tensor parallelism (TP) and pipeline parallelism (PP), allowing you to distribute model inference across multiple GPUs and nodes. This capability enables you to:

- Deploy models that don't fit on a single GPU or node.
- Scale model serving across your cluster's available resources.
- Leverage Ray's placement group strategies to control worker placement for performance or fault tolerance.

::::{note}
By default, Ray Serve LLM uses the `PACK` placement strategy, which tries to place workers on as few nodes as possible. If workers can't fit on a single node, they automatically spill to other nodes. This enables cross-node deployments when single-node resources are insufficient.
::::

## Tensor parallelism

Tensor parallelism splits model weights across multiple GPUs, with each GPU processing a portion of the model's tensors for each forward pass. This approach is useful for models that don't fit on a single GPU.

The following example shows how to configure tensor parallelism across 2 GPUs:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_tp_example_start__
:end-before: __cross_node_tp_example_end__
```
:::

::::

## Pipeline parallelism

Pipeline parallelism splits the model's layers across multiple GPUs, with each GPU processing a subset of the model's layers. This approach is useful for very large models where tensor parallelism alone isn't sufficient.

The following example shows how to configure pipeline parallelism across 2 GPUs:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_pp_example_start__
:end-before: __cross_node_pp_example_end__
```
:::

::::

## Combined tensor and pipeline parallelism

For extremely large models, you can combine both tensor and pipeline parallelism. The total number of GPUs is the product of `tensor_parallel_size` and `pipeline_parallel_size`.

The following example shows how to configure a model with both TP and PP (4 GPUs total):

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_tp_pp_example_start__
:end-before: __cross_node_tp_pp_example_end__
```
:::

::::

## Custom placement groups

You can customize how Ray places vLLM engine workers across nodes through the `placement_group_config` parameter. This parameter accepts a dictionary with `bundles` (a list of resource dictionaries) and `strategy` (the placement strategy).

Ray Serve LLM uses the `PACK` strategy by default, which tries to place workers on as few nodes as possible. If workers can't fit on a single node, they automatically spill to other nodes. For more details on all available placement strategies, see {ref}`Ray Core's placement strategies documentation <pgroup-strategy>`.

::::{note}
Data parallel deployments automatically override the placement strategy to `STRICT_PACK` because each replica must be co-located for correct data parallel behavior.
::::

While you can specify the degree of tensor and pipeline parallelism, the specific assignment of model ranks to GPUs is managed by the vLLM engine and can't be directly configured through the Ray Serve LLM API. Ray Serve automatically injects accelerator type labels into bundles and merges the first bundle with replica actor resources (CPU, GPU, memory).

The following example shows how to use the `SPREAD` strategy to distribute workers across multiple nodes for fault tolerance:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __custom_placement_group_spread_example_start__
:end-before: __custom_placement_group_spread_example_end__
```
:::

::::
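The example file above also includes `PACK` and `STRICT_PACK` variants (the `__custom_placement_group_pack_example_*__` and `__custom_placement_group_strict_pack_example_*__` anchors). As a condensed sketch, the `STRICT_PACK` configuration, which keeps all tensor-parallel workers of a replica on one node, mirrors that example:

```python
# Condensed sketch mirroring the STRICT_PACK example in
# cross_node_parallelism_example.py: require that both tensor-parallel
# workers of each replica land on the same node.
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    accelerator_type="A100",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 2,  # one bundle per tensor-parallel worker
        strategy="STRICT_PACK",    # all bundles must fit on a single node
    ),
)

app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
```

Use `STRICT_PACK` to rule out cross-node communication within a replica, and `SPREAD` when distributing workers across nodes for fault tolerance matters more than locality.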

doc/source/serve/llm/user-guides/index.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ How-to guides for deploying and configuring Ray Serve LLM features.
 ```{toctree}
 :maxdepth: 1
 
+Cross-node parallelism <cross-node-parallelism>
 Data parallel attention <data-parallel-attention>
 Deployment Initialization <deployment-initialization>
 Prefill/decode disaggregation <prefill-decode>
