Adding more details on how to load things.
- Loading with memmap
- Loading a sharded tensor
- Moved some snippets to `candle-examples/src/lib.rs`, because managing book-specific dependencies is a pain (rust-lang/mdBook#706)
- This causes a non-aligned inclusion (rust-lang/mdBook#1856), which we work around with `#[rustfmt::skip]`

mdbook might need some more love :)
Narsil committed Aug 1, 2023
1 parent def0e5a commit d62ea0b
Showing 4 changed files with 143 additions and 13 deletions.
46 changes: 35 additions & 11 deletions candle-book/src/inference/hub.md
@@ -25,6 +25,8 @@ let weights = candle::safetensors::load(weights, &Device::Cpu);

We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.

You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true).
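For example, here is a minimal sketch (not part of this commit) of pulling a single weight out of the returned map; the tensor name is one of the entries listed on that page:

```rust,ignore
use candle::Device;
use hf_hub::api::sync::Api;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();

// `load` returns a HashMap<String, Tensor>, so individual weights are a map lookup away.
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
let query = weights
    .get("bert.encoder.layer.0.attention.self.query.weight")
    .expect("tensor not found");
println!("query weight shape: {:?}", query.dims());
```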


## Using async

@@ -35,17 +37,9 @@ cargo add hf-hub --features tokio
```

```rust,ignore
-# extern crate candle;
-# extern crate hf_hub;
-use hf_hub::api::tokio::Api;
-use candle::Device;
-let api = Api::new().unwrap();
-let repo = api.model("bert-base-uncased".to_string());
-let weights = repo.get("model.safetensors").await.unwrap();
-let weights = candle::safetensors::load(weights, &Device::Cpu);
+# This is tested directly in examples crate because it needs external dependencies unfortunately:
+# See [this](https://github.com/rust-lang/mdBook/issues/706)
+{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
```
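The included snippet uses `.await`, so it needs an async runtime to run outside of the test harness. A minimal sketch (not part of this commit), assuming `tokio` is also added with its `macros` and `rt-multi-thread` features:

```rust,ignore
use candle::Device;
use hf_hub::api::tokio::Api;

#[tokio::main]
async fn main() {
    let api = Api::new().unwrap();
    let repo = api.model("bert-base-uncased".to_string());
    let weights_filename = repo.get("model.safetensors").await.unwrap();
    let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
    println!("loaded {} tensors", weights.len());
}
```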


@@ -78,3 +72,33 @@ let output = linear.forward(&input_ids);
```

For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.

## Memory mapping

For more efficient loading, instead of reading the whole file into memory, you can memory-map it with [`memmap2`](https://docs.rs/memmap2/latest/memmap2/).

**Note**: Be careful with memory mapping: it seems to cause issues on [Windows and WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893),
and it will definitely be slower on a network-mounted disk because it issues more read calls.

```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
```

**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
In practice, model files should never be modified, and the mmaps are mostly read-only anyway, so the caveat most likely does not apply, but always keep it in mind.
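Concretely, for a local file the pattern boils down to the following sketch (assuming a `model.safetensors` path on disk that nothing else modifies while it is mapped):

```rust,ignore
use candle::Device;
use memmap2::Mmap;
use std::fs;

let file = fs::File::open("model.safetensors").unwrap();
// Safety: we rely on the file not being truncated or modified while it is mapped.
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
```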


## Tensor Parallel Sharding

When using multiple GPUs with tensor parallelism in order to get good latency, you can load only the part of the tensor you need.

For that, you need to use the [`safetensors`](https://crates.io/crates/safetensors) crate directly.

```bash
cargo add safetensors
```


```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
```
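The shard bounds used in the snippet are plain integer arithmetic on the split dimension. As an illustration (the `shard_bounds` helper below is hypothetical, not part of candle or of this commit):

```rust,ignore
/// Hypothetical helper: the half-open row range [start, stop) owned by `rank`
/// when a dimension of length `size` is split evenly across `world_size` shards.
fn shard_bounds(size: usize, rank: usize, world_size: usize) -> (usize, usize) {
    assert!(size % world_size == 0, "dimension not divisible by world_size");
    let block_size = size / world_size;
    (rank * block_size, (rank + 1) * block_size)
}

// For the 768x768 query weight split across 4 ranks, rank 1 owns rows 192..384.
assert_eq!(shard_bounds(768, 1, 4), (192, 384));
```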
6 changes: 5 additions & 1 deletion candle-core/src/safetensors.rs
@@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {

pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
    let data = std::fs::read(filename.as_ref())?;
-    let st = safetensors::SafeTensors::deserialize(&data)?;
+    load_buffer(&data[..], device)
+}
+
+pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
+    let st = safetensors::SafeTensors::deserialize(data)?;
    st.tensors()
        .into_iter()
        .map(|(name, view)| Ok((name, view.load(device)?)))
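This refactor keeps the behavior of `load` unchanged while exposing `load_buffer` for callers that already have the serialized bytes in memory (from a plain read, an mmap, a network buffer, ...). A minimal usage sketch, assuming a local `model.safetensors` path:

```rust
use candle::Device;

fn main() {
    // The bytes can come from anywhere; here we simply read the file.
    let data = std::fs::read("model.safetensors").unwrap();
    let tensors = candle::safetensors::load_buffer(&data[..], &Device::Cpu).unwrap();
    println!("loaded {} tensors", tensors.len());
}
```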
5 changes: 4 additions & 1 deletion candle-examples/Cargo.toml
@@ -26,14 +26,17 @@ half = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
-hf-hub = { workspace = true}
+hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
+# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
+tokio = "1.29.1"
+memmap2.workspace = true

[build-dependencies]
anyhow = { workspace = true }
99 changes: 99 additions & 0 deletions candle-examples/src/lib.rs
@@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
        Ok(device)
    }
}

#[cfg(test)]
mod tests {
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights_filename = repo.get("model.safetensors").await.unwrap();

let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}

#[rustfmt::skip]
#[test]
fn book_hub_2() {
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();

let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}

#[rustfmt::skip]
#[test]
fn book_hub_3() {
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();

let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };

// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();

// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];

if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;

// Everything is expressed in tensor dimensions;
// byte offsets are handled automatically by safetensors.

let iterator = view.slice(start..stop).unwrap();

tp_shape[dim] = block_size;

// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();

// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
}
