diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs index ef72f43..f4a3398 100644 --- a/native/candlex/src/lib.rs +++ b/native/candlex/src/lib.rs @@ -1,7 +1,8 @@ mod atoms { rustler::atoms! { cpu, - cuda + cuda, + metal } } diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 78c94b4..627c4f4 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -26,6 +26,7 @@ impl ExTensor { let dev_string = match tensor.device() { Device::Cpu => atoms::cpu(), Device::Cuda(_) => atoms::cuda(), + Device::Metal(_) => atoms::metal(), }; Self {