Multi-node support for tensor parallelism #1125

Merged
32 commits, merged Feb 11, 2025

Commits
e11f166
Multi-node support, server?
EricLBuehler Feb 7, 2025
8eed926
Fix check
EricLBuehler Feb 7, 2025
82e801a
Remove check
EricLBuehler Feb 7, 2025
0c926aa
Remove check
EricLBuehler Feb 7, 2025
3777371
Debug
EricLBuehler Feb 7, 2025
3759ce6
Debug
EricLBuehler Feb 7, 2025
a9d7d98
counter+=1
EricLBuehler Feb 7, 2025
1700f2a
counter+=1
EricLBuehler Feb 7, 2025
9425b29
Replicate kv heads?
EricLBuehler Feb 9, 2025
3ee2560
Replicate kv heads?
EricLBuehler Feb 9, 2025
1c988d1
Replicate kv heads?
EricLBuehler Feb 9, 2025
05e618b
Shard with specific offset
EricLBuehler Feb 9, 2025
73c68e0
Shard with specific offset
EricLBuehler Feb 9, 2025
dcfa231
Replicate kv heads
EricLBuehler Feb 9, 2025
77fcc1c
Fix num kv groups
EricLBuehler Feb 9, 2025
f0f967c
Fix num kv groups
EricLBuehler Feb 9, 2025
81fca45
Debug
EricLBuehler Feb 10, 2025
e71c7d6
Refactor to client and server
EricLBuehler Feb 11, 2025
cc005d8
Use multi-node synchronization
EricLBuehler Feb 11, 2025
92731f5
Debugging
EricLBuehler Feb 11, 2025
c866716
Hierarchical synchronization design
EricLBuehler Feb 11, 2025
5e2f4ad
It works!
EricLBuehler Feb 11, 2025
ac38e05
Add some docs
EricLBuehler Feb 11, 2025
1a8f376
Add logging
EricLBuehler Feb 11, 2025
0d2c240
Add logging
EricLBuehler Feb 11, 2025
8159b91
Update docs
EricLBuehler Feb 11, 2025
4bddcc8
Try to reuse socket
EricLBuehler Feb 11, 2025
f163ab4
Try to reuse socket
EricLBuehler Feb 11, 2025
4893230
Maybe its faster
EricLBuehler Feb 11, 2025
00a58eb
Minimize tcp traffic
EricLBuehler Feb 11, 2025
659aa25
Clippy
EricLBuehler Feb 11, 2025
a24c431
Set some timeouts
EricLBuehler Feb 11, 2025
160 changes: 80 additions & 80 deletions Cargo.lock

Large diffs are not rendered by default.

39 changes: 38 additions & 1 deletion docs/DISTRIBUTED.md
@@ -1,4 +1,4 @@
# Distributed inference in mistral.rs
# Distributed inference in mistral.rs: tensor parallelism and multi-node support

Mistral.rs supports distributed inference on CUDA with Tensor Parallelism via NCCL.

@@ -16,3 +16,40 @@ See the following environment variables:
|--|--|--|
|`MISTRALRS_NO_NCCL=1`|Disable TP and NCCL|If the model does not fit on the available CUDA devices, disabling NCCL will re-enable automatic device mapping|
|`MISTRALRS_PIPELINE_PARALLEL=<number> (default: 1 = disabled)`|Parallelize the model along the layers in addition to the GPUs|Increasing this value is useful for tuning performance on a model-specific basis. It does not change the number of GPUs required, but can help when the single-node interconnects are a bottleneck.|

## Multi-node support

```
# Head node:
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_HEAD_NUM_WORKERS=1 MISTRALRS_MN_HEAD_PORT=<PORT> cargo run --release --features cuda -- -i plain -m ...

# For the worker nodes:
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_WORKER_ID=0 MISTRALRS_MN_WORKER_SERVER_ADDR=<HEAD ADDR>:<PORT> cargo run --release --features cuda -- -i plain -m ...
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_WORKER_ID=1 MISTRALRS_MN_WORKER_SERVER_ADDR=<HEAD ADDR>:<PORT> cargo run --release --features cuda -- -i plain -m ...
MISTRALRS_MN_GLOBAL_WORLD_SIZE=32 MISTRALRS_MN_WORKER_ID=2 MISTRALRS_MN_WORKER_SERVER_ADDR=<HEAD ADDR>:<PORT> cargo run --release --features cuda -- -i plain -m ...
```

Multi-node support in mistral.rs divides the nodes into two roles: one "head" node and multiple "worker" nodes. The choice of head node is arbitrary.
For example, if a system has 8 nodes, there will be 1 "head" node and 7 "worker" nodes.

To enable multi-node, set the `MISTRALRS_MN_GLOBAL_WORLD_SIZE=<number>` environment variable to the total number of GPUs across all nodes, head and workers included.

> Note: `MISTRALRS_PIPELINE_PARALLEL` is incompatible with multi-node (i.e., with setting `MISTRALRS_MN_GLOBAL_WORLD_SIZE`).

It is recommended to run mistral.rs in server mode when using multi-node. **Currently, you must send requests to every node!**

The following environment variables must be set for each node:

**Head node:**

|Name|Function|Usage|
|--|--|--|
|`MISTRALRS_MN_HEAD_NUM_WORKERS=<number>`|The number of worker nodes that will connect.|This should be the number of nodes in the system, minus 1 for the head node.|
|`MISTRALRS_MN_HEAD_PORT=<PORT>`|The port on which to communicate with the worker nodes.|Worker nodes will connect to this port via TCP sockets.|

**Worker node:**

|Name|Function|Usage|
|--|--|--|
|`MISTRALRS_MN_WORKER_ID=<number>`|The 0-indexed worker ID for this worker node.|If there are 4 nodes (1 head, 3 workers), the worker IDs will be 0, 1, and 2.|
|`MISTRALRS_MN_WORKER_SERVER_ADDR=<ADDR>:<PORT>`|The IP address and port of the head node's server.|This is used to establish communication with the head node.|
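
Taken together, these variables determine the global NCCL rank of every GPU: the head node owns ranks `0..local_world_size`, and worker `i` owns ranks starting at `(i + 1) * local_world_size`. A minimal Rust sketch of that assignment (the helper name is illustrative, not part of the mistral.rs API; the arithmetic mirrors the `rank_offset` computation in `mistralrs-core/src/pipeline/normal.rs` later in this diff):

```rust
/// Global NCCL rank for one GPU in a multi-node run.
/// `worker_id` is `None` on the head node, or `Some(id)` on a worker
/// (matching `MISTRALRS_MN_WORKER_ID`).
fn global_rank(worker_id: Option<usize>, local_rank: usize, local_world_size: usize) -> usize {
    let rank_offset = match worker_id {
        // Worker `id` starts after the head node and the `id` workers before it.
        Some(id) => (id + 1) * local_world_size,
        // The head node owns ranks 0..local_world_size.
        None => 0,
    };
    rank_offset + local_rank
}

fn main() {
    // Example: 4 nodes x 8 GPUs, so MISTRALRS_MN_GLOBAL_WORLD_SIZE=32.
    assert_eq!(global_rank(None, 3, 8), 3); // head node, GPU 3
    assert_eq!(global_rank(Some(0), 0, 8), 8); // worker 0, GPU 0
    assert_eq!(global_rank(Some(2), 7, 8), 31); // worker 2, GPU 7
}
```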
31 changes: 27 additions & 4 deletions mistralrs-core/src/models/llama.rs
@@ -3,7 +3,7 @@
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer, Shard,
ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
@@ -193,20 +193,39 @@ impl CausalSelfAttention {
comm,
vb.pp("q_proj"),
)?;
let k_proj = ColumnParallelLayer::new(

// We may need to replicate the kv heads
let kv_replicate = if comm.world_size() > cfg.num_key_value_heads {
comm.world_size() / cfg.num_key_value_heads
} else {
1
};

let kv_shard_id = comm.rank() / kv_replicate;
// Each shard covers one kv head's rows: head_dim = hidden_size / num_attention_heads
let kv_block_size = cfg.hidden_size / cfg.num_attention_heads;
let shard = Shard::Offset {
dim: 0,
offset: kv_shard_id * kv_block_size,
len: kv_block_size,
};

let k_proj = ColumnParallelLayer::new_with_shard(
size_in,
size_kv,
&cfg.quantization_config,
false,
comm,
shard,
vb.pp("k_proj"),
)?;
let v_proj = ColumnParallelLayer::new(
let v_proj = ColumnParallelLayer::new_with_shard(
size_in,
size_kv,
&cfg.quantization_config,
false,
comm,
shard,
vb.pp("v_proj"),
)?;
let o_proj = RowParallelLayer::new(
@@ -229,7 +248,11 @@
max_seq_len: cfg.max_position_embeddings,
paged_attn,
sdpa_params: SdpaParams {
n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
n_kv_groups: if kv_replicate != 1 {
(cfg.num_attention_heads / cfg.num_key_value_heads) / kv_replicate
} else {
cfg.num_attention_heads / cfg.num_key_value_heads
},
use_flash_attn: cfg.use_flash_attn,
softcap: None,
softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
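To make the kv-head replication above concrete, here is a small worked example of the same arithmetic (a standalone sketch with illustrative config values, not code from this PR): with 8 kv heads and a world size of 32, each kv head is served by 4 consecutive ranks, and the query-to-kv group count shrinks by that replication factor.

```rust
fn main() {
    // Illustrative Llama-like config values.
    let num_attention_heads = 64;
    let num_key_value_heads = 8;
    let hidden_size = 8192;
    let world_size = 32; // tensor-parallel world size > kv heads

    // Replication factor, as computed in CausalSelfAttention above.
    let kv_replicate = if world_size > num_key_value_heads {
        world_size / num_key_value_heads // = 4
    } else {
        1
    };
    // One kv head's slice of rows: head_dim = hidden_size / num_attention_heads.
    let kv_block_size = hidden_size / num_attention_heads; // = 128

    for rank in [0, 3, 4, 31] {
        let kv_shard_id = rank / kv_replicate; // ranks 0..=3 load head 0, 4..=7 head 1, ...
        let offset = kv_shard_id * kv_block_size;
        println!("rank {rank}: kv head {kv_shard_id}, row offset {offset}");
    }

    // n_kv_groups shrinks by the replication factor:
    // (64 / 8) / 4 = 2 query heads per kv head on each rank.
    let n_kv_groups = (num_attention_heads / num_key_value_heads) / kv_replicate;
    assert_eq!(n_kv_groups, 2);
}
```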
103 changes: 93 additions & 10 deletions mistralrs-core/src/pipeline/normal.rs
@@ -38,7 +38,7 @@ use crate::{
normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
};
use anyhow::Result;
use anyhow::{Context, Result};
use candle_core::{Device, Tensor, Var};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indicatif::MultiProgress;
@@ -51,12 +51,12 @@ use rayon::iter::{
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::fs;
use std::num::{NonZero, NonZeroUsize};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, Barrier, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};
@@ -300,7 +300,8 @@ impl Loader for NormalLoader {

let use_nccl = available_devices.iter().all(|dev| dev.is_cuda())
&& available_devices.len() > 1
&& std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1");
&& (std::env::var("MISTRALRS_NO_NCCL").is_err()
|| std::env::var("MISTRALRS_NO_NCCL").is_ok_and(|x| x != "1"));

// If auto, convert to Map if not using nccl
if use_nccl {
@@ -476,15 +477,63 @@ impl Loader for NormalLoader {
anyhow::bail!("MISTRALRS_PIPELINE_PARALLEL must be nonzero")
}

let world_size = available_devices.len() / pipeline_parallel_size;
let local_world_size = available_devices.len() / pipeline_parallel_size;
let global_world_size = if let Ok(x) = std::env::var("MISTRALRS_MN_GLOBAL_WORLD_SIZE") {
usize::from_str(&x).context("MISTRALRS_MN_GLOBAL_WORLD_SIZE")?
} else {
local_world_size
};

let use_multi_node = global_world_size != local_world_size;
if use_multi_node {
info!("Global world size != local world size, entering multi-node.");
}

info!("Tensor parallel world size is {world_size}");
if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
}

info!("Local tensor parallel world size is {local_world_size}");
info!("Global tensor parallel world size is {global_world_size}");
info!("Pipeline parallelism size is {pipeline_parallel_size}");

let ids = (0..pipeline_parallel_size)
let mut ids = (0..pipeline_parallel_size)
.map(|_| mistralrs_quant::Id::new())
.collect::<Vec<_>>();

if use_multi_node && ids.len() != 1 {
anyhow::bail!(
"MISTRALRS_PIPELINE_PARALLEL cannot be set at the same time as MISTRALRS_MN_GLOBAL_WORLD_SIZE; multi-node is incompatible with pipeline parallel."
);
}

if use_multi_node {
let id = &mut ids[0];
if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
let n_nodes =
usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
info!("Head node managing {n_nodes} workers.");
let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
anyhow::bail!(
"Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT"
);
};
info!("Head node initializing connection on {port}.");
let server = mistralrs_quant::Server::new(
&format!("0.0.0.0:{port}"),
n_nodes,
local_world_size,
)?;

server.broadcast_id(id)?;
} else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
info!("Worker node connecting to {addr}.");
let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;

*id = client.receive_id()?;
}
}

if available_devices.len() % ids.len() != 0 {
anyhow::bail!(
"Pipeline parallel size {} must divide the number of available devices {}",
@@ -497,13 +546,47 @@ impl Loader for NormalLoader {
.chunks(available_devices.len() / pipeline_parallel_size)
.collect::<Vec<_>>();

let rank_offset = if env::var("MISTRALRS_MN_WORKER_SERVER_ADDR").is_ok() {
let Ok(node_id) = env::var("MISTRALRS_MN_WORKER_ID") else {
anyhow::bail!(
"Got MISTRALRS_MN_WORKER_SERVER_ADDR, expected MISTRALRS_MN_WORKER_ID"
);
};
let node_id = usize::from_str(&node_id).context("MISTRALRS_MN_WORKER_ID")?;
info!("Worker ID is {node_id}.");
(node_id + 1) * local_world_size
} else {
0
};

// Transpose
let mut comms_all = Vec::new();
for (pipeline_parallel_i, devices_per_pipeline_parallel) in
split_available_devices.iter().enumerate()
{
// Each pipeline parallel gets its own barrier
let barrier = Arc::new(Barrier::new(world_size));
let barrier = if let Ok(n_nodes) = env::var("MISTRALRS_MN_HEAD_NUM_WORKERS") {
let n_nodes =
usize::from_str(&n_nodes).context("MISTRALRS_MN_HEAD_NUM_WORKERS")?;
let Ok(port) = env::var("MISTRALRS_MN_HEAD_PORT") else {
anyhow::bail!(
"Got MISTRALRS_MN_HEAD_NUM_WORKERS, expected MISTRALRS_MN_HEAD_PORT"
);
};
let server = mistralrs_quant::Server::new(
&format!("0.0.0.0:{port}"),
n_nodes,
local_world_size,
)?;

Arc::new(server) as Arc<dyn mistralrs_quant::BarrierLike>
} else if let Ok(addr) = env::var("MISTRALRS_MN_WORKER_SERVER_ADDR") {
let client = mistralrs_quant::Client::new(addr.parse()?, local_world_size)?;
Arc::new(client) as Arc<dyn mistralrs_quant::BarrierLike>
} else {
Arc::new(Barrier::new(local_world_size))
as Arc<dyn mistralrs_quant::BarrierLike>
};

// They each block on each other
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
@@ -522,8 +605,8 @@ impl Loader for NormalLoader {
mistralrs_quant::Comm::from_device(
ids[pipeline_parallel_i],
device,
rank,
world_size,
rank + rank_offset,
global_world_size,
barrier.clone(),
)
})
@@ -539,7 +622,7 @@ impl Loader for NormalLoader {

// row major: number of ranks x pipeline parallel
// Also corresponds to the device for that comm for the
let comms = (0..world_size)
let comms = (0..local_world_size)
.map(|pipeline_parallel_i| {
comms_all
.iter()
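The hunk above swaps the in-process `std::sync::Barrier` for a TCP-backed rendezvous on multi-node runs, behind a shared `BarrierLike` trait object. The trait's definition is not shown in this diff, so the sketch below is an assumption inferred from the call sites, not the PR's actual code:

```rust
use std::io::Result;
use std::sync::Barrier;

// Assumed shape of the abstraction: anything all local ranks can
// block on together. `Server` and `Client` would implement this over
// TCP so the barrier spans nodes as well as threads.
pub trait BarrierLike: Send + Sync {
    fn wait(&self) -> Result<()>;
}

// Single-node case: defer to the in-process barrier.
impl BarrierLike for Barrier {
    fn wait(&self) -> Result<()> {
        Barrier::wait(self); // inherent method, disambiguated from the trait
        Ok(())
    }
}
```

Under this reading, `Server::wait` would block until all `n_nodes` workers have checked in on their sockets and then release them, while `Client::wait` pings the head and blocks on its reply: the "hierarchical synchronization design" named in commit c866716.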
25 changes: 19 additions & 6 deletions mistralrs-quant/src/distributed/layers.rs
@@ -10,7 +10,7 @@ use crate::{
};

fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
Shard {
Shard::Simple {
dim,
rank,
world_size,
@@ -195,18 +195,15 @@ pub struct ColumnParallelLayer {

impl ColumnParallelLayer {
#[allow(clippy::new_ret_no_self)]
pub fn new(
pub fn new_with_shard(
in_dim: usize,
out_dim: usize,
config: &Option<QuantizedConfig>,
bias: bool,
comm: &Arc<crate::Comm>,
shard: Shard,
vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
let rank = comm.rank();
let world_size = comm.world_size();
let shard = shard(0, rank, world_size);

let weight = if let Some(quant_conf) = &config {
// GPTQ and BNB do not support tensor parallelism
if matches!(
@@ -259,6 +256,22 @@ impl ColumnParallelLayer {

Ok(Arc::new(Self { weight, bias }))
}

#[allow(clippy::new_ret_no_self)]
pub fn new(
in_dim: usize,
out_dim: usize,
config: &Option<QuantizedConfig>,
bias: bool,
comm: &Arc<crate::Comm>,
vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
let rank = comm.rank();
let world_size = comm.world_size();
let shard = shard(0, rank, world_size);

Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
}
}

impl QuantMethod for ColumnParallelLayer {
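The `Shard::Simple` change above pairs with the new `Shard::Offset` variant used by `llama.rs` to replicate kv heads. The enum's definition is not part of this diff; inferred from its two call sites, it plausibly looks like the sketch below (an assumption, not the actual `mistralrs-quant` source):

```rust
/// Sketch of `Shard` as inferred from its uses in this PR.
#[derive(Debug, Clone, Copy)]
pub enum Shard {
    /// Even split: rank `rank` of `world_size` takes the rank-th
    /// equal slice along `dim` (the pre-existing behavior).
    Simple { dim: usize, rank: usize, world_size: usize },
    /// Explicit slice: take `len` elements starting at `offset` along
    /// `dim`. This is what lets several ranks load the same kv-head
    /// rows when the world size exceeds the number of kv heads.
    Offset { dim: usize, offset: usize, len: usize },
}
```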