From e04dd4fea568bcc6ee5e3f536a57488cdba0d772 Mon Sep 17 00:00:00 2001 From: optman Date: Fri, 26 Jan 2024 03:46:30 +0800 Subject: [PATCH] make non-const conv1d available to stable rust (#911) * make device capable of Conv1DKernel * make non-const conv1d available to stable rust --------- Co-authored-by: Corey Lowman --- dfdx-core/src/tensor_ops/conv1d/mod.rs | 5 +++-- dfdx-core/src/tensor_ops/mod.rs | 2 -- dfdx-core/src/tensor_ops/utilities/device.rs | 3 +++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dfdx-core/src/tensor_ops/conv1d/mod.rs b/dfdx-core/src/tensor_ops/conv1d/mod.rs index db51c02b..5fd8e23b 100644 --- a/dfdx-core/src/tensor_ops/conv1d/mod.rs +++ b/dfdx-core/src/tensor_ops/conv1d/mod.rs @@ -9,7 +9,7 @@ mod tests; #[repr(C)] #[derive(Debug, Copy, Clone)] -pub(super) struct Conv1DOp { +pub struct Conv1DOp { pub kernel: usize, pub stride: usize, pub padding: usize, @@ -22,7 +22,7 @@ pub(super) struct Conv1DOp { pub l_out: usize, } -pub(super) trait Conv1DKernel: Storage { +pub trait Conv1DKernel: Storage { fn alloc(&self, s: S) -> Result, Error>; fn forward( @@ -108,6 +108,7 @@ pub trait TryConv1D: Sized { ) -> Result; } +#[cfg(feature = "nightly")] impl< const KERNEL: usize, const STRIDE: usize, diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index d934b678..453457f4 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -281,9 +281,7 @@ pub use upscale2d::{ }; pub use var_to::VarTo; -#[cfg(feature = "nightly")] mod conv1d; -#[cfg(feature = "nightly")] pub use conv1d::TryConv1D; #[cfg(feature = "nightly")] diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 00fa9502..8cbc2137 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -112,6 +112,9 @@ pub trait Device: + BinaryKernel + BinaryKernel + crate::tensor_ops::axpy::AxpyKernel + + // conv1d + + super::super::conv1d::Conv1DKernel { }