Skip to content

Commit

Permalink
Add candle CudaDevice and MetalDevice to avoid creating a new unique device each time (#2290)

Browse files Browse the repository at this point in the history

* Add candle CudaDevice and MetalDevice to avoid creating a new unique device each time

* Fix doc example

* Change enum usage
  • Loading branch information
laggui authored Sep 25, 2024
1 parent 37d8795 commit 112f09e
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 23 deletions.
4 changes: 2 additions & 2 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ macro_rules! bench_on_backend {
use burn::backend::candle::CandleDevice;
use burn::backend::Candle;

let device = CandleDevice::Cuda(0);
let device = CandleDevice::cuda(0);
bench::<Candle>(&device, feature_name, url, token);
}

Expand All @@ -123,7 +123,7 @@ macro_rules! bench_on_backend {
use burn::backend::candle::CandleDevice;
use burn::backend::Candle;

let device = CandleDevice::Metal(0);
let device = CandleDevice::metal(0);
bench::<Candle>(&device, feature_name, url, token);
}

Expand Down
104 changes: 91 additions & 13 deletions crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationStrategy},
Device,
};
use candle_core::DeviceLocation;
use candle_core::{backend::BackendDevice, DeviceLocation};

use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
Expand All @@ -16,7 +16,7 @@ use crate::{
///
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
#[derive(Clone, Copy, Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct Candle<F = f32, I = i64>
where
F: FloatCandleElement,
Expand All @@ -27,29 +27,89 @@ where
}

/// The device type for the candle backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
/// The device struct when using the `candle` backend.
///
/// Note that you need to provide the device index when using Cuda.
/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
/// ```no_run
/// use burn_candle::CandleDevice;
///
/// // Create a Cuda device from its index
/// let device = CandleDevice::cuda(0);
/// // Create a Metal device from its index
/// let device = CandleDevice::metal(0);
/// ```
pub enum CandleDevice {
/// CPU device.
Cpu,

/// Cuda device with the given index. The index is the index of the Cuda device in the list of
/// all Cuda devices found on the system.
Cuda(usize),
Cuda(CudaDevice),

/// Metal device with the given index. The index is the index of the Metal device in the list of
/// all Metal devices found on the system.
Metal(usize),
Metal(MetalDevice),
}

impl CandleDevice {
    /// Create a Cuda device with the given index.
    /// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
    ///
    /// # Panics
    ///
    /// Panics if candle cannot initialize a Cuda device with the given index.
    pub fn cuda(index: usize) -> Self {
        CandleDevice::Cuda(CudaDevice {
            // Create the candle device once and cache it in the variant so it is
            // reused instead of constructing a new unique device on every use.
            device: candle_core::CudaDevice::new(index)
                .expect("Cuda device should exist at the given index"),
            index,
        })
    }

    /// Create a Metal device with the given index.
    /// The index is the index of the Metal device in the list of all Metal devices found on the system.
    ///
    /// # Panics
    ///
    /// Panics if candle cannot initialize a Metal device with the given index.
    pub fn metal(index: usize) -> Self {
        CandleDevice::Metal(MetalDevice {
            // Same caching rationale as `cuda` above.
            device: candle_core::MetalDevice::new(index)
                .expect("Metal device should exist at the given index"),
            index,
        })
    }
}

#[derive(Clone, Debug)]
/// A Cuda device for the `candle` backend.
pub struct CudaDevice {
    // The cached candle device handle; kept so the same device is reused
    // instead of creating a new unique device each time.
    pub(crate) device: candle_core::CudaDevice,
    /// The index of the Cuda device in the list of all devices on the system.
    pub index: usize,
}

// Equality requires both the underlying candle device handles to be the same
// device AND the recorded indices to match.
impl PartialEq for CudaDevice {
    fn eq(&self, other: &Self) -> bool {
        // Guard clause: distinct underlying devices can never be equal.
        if !self.device.same_device(&other.device) {
            return false;
        }
        self.index == other.index
    }
}

impl Eq for CudaDevice {}

#[derive(Clone, Debug)]
/// A Metal device for the `candle` backend.
pub struct MetalDevice {
    // The cached candle device handle; kept so the same device is reused
    // instead of creating a new unique device each time.
    pub(crate) device: candle_core::MetalDevice,
    /// The index of the Metal device in the list of all devices on the system.
    pub index: usize,
}

// Equality requires both the underlying candle device handles to be the same
// device AND the recorded indices to match.
impl PartialEq for MetalDevice {
    fn eq(&self, other: &Self) -> bool {
        // Guard clause: distinct underlying devices can never be equal.
        if !self.device.same_device(&other.device) {
            return false;
        }
        self.index == other.index
    }
}

impl Eq for MetalDevice {}

impl From<CandleDevice> for candle_core::Device {
    /// Converts a burn [`CandleDevice`] into the candle device it wraps.
    ///
    /// The cached handle stored in the variant is moved out directly, so no
    /// new device is constructed.
    fn from(device: CandleDevice) -> Self {
        match device {
            CandleDevice::Cpu => candle_core::Device::Cpu,
            CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
            CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
        }
    }
}
Expand All @@ -58,8 +118,26 @@ impl From<candle_core::Device> for CandleDevice {
/// Wraps a candle device into a [`CandleDevice`], keeping the handle so it can
/// be reused later.
///
/// # Panics
///
/// Panics if the device's reported location does not match its variant
/// (a broken invariant in `candle_core`).
fn from(device: candle_core::Device) -> Self {
    match device.location() {
        DeviceLocation::Cpu => CandleDevice::Cpu,
        DeviceLocation::Cuda { gpu_id } => {
            // A Cuda location implies a Cuda variant; keep the handle and
            // record the gpu index reported by candle.
            if let candle_core::Device::Cuda(device) = device {
                CandleDevice::Cuda(CudaDevice {
                    device,
                    index: gpu_id,
                })
            } else {
                panic!("Expected CUDA device.");
            }
        }
        DeviceLocation::Metal { gpu_id } => {
            if let candle_core::Device::Metal(device) = device {
                CandleDevice::Metal(MetalDevice {
                    device,
                    index: gpu_id,
                })
            } else {
                panic!("Expected Metal device.");
            }
        }
    }
}
}
Expand All @@ -68,8 +146,8 @@ impl DeviceOps for CandleDevice {
/// Returns a stable `DeviceId` for this device.
///
/// The type id encodes the variant (0 = Cpu, 1 = Cuda, 2 = Metal) and the
/// index id is the device's position among devices of that kind.
fn id(&self) -> burn_tensor::backend::DeviceId {
    match self {
        CandleDevice::Cpu => DeviceId::new(0, 0),
        CandleDevice::Cuda(device) => DeviceId::new(1, device.index as u32),
        CandleDevice::Metal(device) => DeviceId::new(2, device.index as u32),
    }
}
}
Expand Down Expand Up @@ -111,7 +189,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
fn sync(device: &Device<Self>, sync_type: SyncType) {
match sync_type {
SyncType::Wait => {
let device: candle_core::Device = (*device).into();
let device: candle_core::Device = (device.clone()).into();

match device {
candle_core::Device::Cpu => (),
Expand Down
8 changes: 5 additions & 3 deletions crates/burn-candle/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub fn cat<E: CandleElement>(tensors: Vec<CandleTensor<E>>, dim: usize) -> Candl
}

/// Creates a candle tensor on `device` from the given data.
pub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor<E> {
    // `CandleDevice` is not `Copy` (it caches a device handle), so clone it
    // for the tensor constructor.
    CandleTensor::from_data(data, device.clone())
}
pub fn into_data<E: CandleElement>(tensor: CandleTensor<E>) -> TensorData {
TensorData::new(
Expand All @@ -28,11 +28,13 @@ pub fn to_device<E: CandleElement>(
tensor: CandleTensor<E>,
device: &CandleDevice,
) -> CandleTensor<E> {
CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap())
CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())
}

/// Creates an "empty" tensor with the given shape on `device`.
///
/// The allocation is zero-filled — candle's `zeros` is used as the backing
/// constructor here, so "empty" does not mean uninitialized.
pub fn empty<E: CandleElement>(shape: Shape, device: &CandleDevice) -> CandleTensor<E> {
    CandleTensor::new(
        candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(device.clone()).into()).unwrap(),
    )
}

pub fn swap_dims<E: CandleElement>(
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F

/// Creates an integer tensor of zeros with the given shape on `device`.
fn int_zeros(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
    CandleTensor::new(
        // Clone the device handle (no longer `Copy`) and convert it to the
        // underlying candle device.
        candle_core::Tensor::zeros(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
    )
}

/// Creates an integer tensor of ones with the given shape on `device`.
fn int_ones(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
    CandleTensor::new(
        // Clone the device handle (no longer `Copy`) and convert it to the
        // underlying candle device.
        candle_core::Tensor::ones(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
    )
}

Expand Down Expand Up @@ -324,7 +324,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
device: &Device<Self>,
) -> IntTensor<Self> {
let shape = shape.dims;
let device = &(*device).into();
let device = &(device.clone()).into();
match distribution {
Distribution::Default => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::base::{expand, permute, sign};

impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
/// Creates a float tensor on `device` from the given data.
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor<F> {
    // `CandleDevice` is not `Copy` (it caches a device handle), so clone it
    // for the tensor constructor.
    CandleTensor::from_data(data, device.clone())
}

fn float_random(
Expand All @@ -24,7 +24,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
device: &Device<Self>,
) -> FloatTensor<Self> {
let shape = shape.dims;
let device = &(*device).into();
let device = &(device.clone()).into();
match distribution {
Distribution::Default => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)
Expand Down

0 comments on commit 112f09e

Please sign in to comment.