diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml
index 5309ef7c..0f6cd5c6 100644
--- a/dfdx-core/Cargo.toml
+++ b/dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
 safetensors = { workspace = true, optional = true }
 memmap2 = { workspace = true, optional = true }
 half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
-gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
+gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
 rayon = { version = "1.7.0", optional = true }
 libm = { workspace = true }
 wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
diff --git a/dfdx-core/src/data/collate.rs b/dfdx-core/src/data/collate.rs
index d38a2a67..5f52d636 100644
--- a/dfdx-core/src/data/collate.rs
+++ b/dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
 impl<'a, A, B> Collate for Vec<&'a (A, B)> {
     type Collated = (Vec<&'a A>, Vec<&'a B>);
     fn collated(self) -> Self::Collated {
+        #[allow(clippy::map_identity)]
         self.into_iter().map(|(a, b)| (a, b)).unzip()
     }
 }
diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs
index 31e61643..c126db2c 100644
--- a/dfdx-core/src/lib.rs
+++ b/dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
     pub use crate::tensor_ops::*;
 }
 
-/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn flush_denormals_to_zero() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-}
-
-/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn keep_denormals() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-}
-
 #[cfg(test)]
 pub(crate) mod tests {
     pub use num_traits::{Float, NumCast, Zero};
diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs
index 86974ec6..d24e2e32 100644
--- a/dfdx-core/src/tensor/gradients.rs
+++ b/dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
     #[inline]
     pub(crate) fn many_and_ref<L: Shape, R: Shape>(
         &mut self,
-        ls: &Vec<impl Tensorlike<L, E, D>>,
+        ls: &[impl Tensorlike<L, E, D>],
         r: &impl Tensorlike<R, E, D>,
     ) -> (Vec<&mut D::Vec>, &D::Vec) {
         for i in 0..ls.len() {
diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs
index 453457f4..a649196c 100644
--- a/dfdx-core/src/tensor_ops/mod.rs
+++ b/dfdx-core/src/tensor_ops/mod.rs
@@ -184,6 +184,7 @@ mod mul;
 mod nans_to;
 mod negate;
 mod normalize;
+mod normalize_rms;
 pub(super) mod optim;
 mod permute_to;
 mod pow;
@@ -251,6 +252,7 @@ pub use mul::{mul, TryMul};
 pub use nans_to::nans_to;
 pub use negate::negate;
 pub use normalize::normalize;
+pub use normalize_rms::normalize_rms;
 pub use optim::*;
 pub use permute_to::PermuteTo;
 pub use pow::{powf, powi};
diff --git a/dfdx-core/src/tensor_ops/normalize_rms.rs b/dfdx-core/src/tensor_ops/normalize_rms.rs
new file mode 100644
index 00000000..eb70302a
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/normalize_rms.rs
@@ -0,0 +1,136 @@
+use crate::{
+    shapes::{Axes, Dtype, ReduceShape, Shape},
+    tensor::{Error, Tape, Tensor},
+};
+
+use super::{BroadcastTo, Device, MeanTo, TryAdd, TryMul};
+
+/// Normalizes `t` to have a root-mean-square of `1.0` along `Ax`. `epsilon` is added to the mean square for numerical stability.
+/// Computes `t / (t.square().mean() + epsilon).sqrt()`.
+///
+/// Normalizing a single axis:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
+/// let _ = t.normalize_rms::<Axis<1>>(1e-5);
+/// ```
+pub fn normalize_rms<
+    Ax: Axes,
+    S: Shape + ReduceShape<Ax>,
+    E: Dtype,
+    D: Device<E>,
+    T: Tape<E, D>,
+>(
+    t: Tensor<S, E, D, T>,
+    epsilon: impl Into<f64>,
+) -> Tensor<S, E, D, T> {
+    t.normalize_rms::<Ax>(epsilon)
+}
+
+impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
+    /// See [normalize_rms]
+    pub fn normalize_rms<Ax: Axes>(self, epsilon: impl Into<f64>) -> Self
+    where
+        S: ReduceShape<Ax>,
+    {
+        self.try_normalize_rms::<Ax>(epsilon).unwrap()
+    }
+
+    /// See [normalize_rms]
+    pub fn try_normalize_rms<Ax: Axes>(self, epsilon: impl Into<f64>) -> Result<Self, Error>
+    where
+        S: ReduceShape<Ax>,
+    {
+        let shape = self.shape;
+        let sq = self.retaped::<T>().try_square()?;
+        let sq_mean = sq.try_mean::<_, Ax>()?;
+        let rsqrt = sq_mean
+            .try_add(epsilon)?
+            .try_sqrt()?
+            .try_recip()?
+            .try_broadcast_like(&shape)?;
+        self.try_mul(rsqrt)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::*;
+    use crate::{shapes::*, tensor::*, tensor_ops::*};
+
+    #[test]
+    fn test_1d_normalize_rms_axis_last() {
+        let dev: TestDevice = Default::default();
+        let a = dev.tensor([-2.0, 0.0, 5.0]).to_dtype::<TestDtype>();
+        let r = a.leaky_trace().normalize_rms(1e-5);
+        assert_close_to_literal!(&r, [-0.64326715, 0.0, 1.6081679]);
+        // NOTE: .exp() so we can make sure normalize is using result grad properly
+        let g = r.exp().mean().backward();
+        assert_close_to_literal!(&g.get(&a), [0.23318729, 0.107211195, 0.09327549]);
+    }
+
+    #[test]
+    fn test_2d_normalize_rms_axis_last() {
+        let dev: TestDevice = Default::default();
+        let a = dev
+            .tensor([[-2.0, 0.0, 5.0], [1.0, 2.0, 3.0]])
+            .to_dtype::<TestDtype>();
+        let r = a.leaky_trace().normalize_rms::<Axis<1>>(1e-5);
+        assert_close_to_literal!(
+            r,
+            [
+                [-0.64326715, 0.0, 1.6081679],
+                [0.46290955, 0.9258191, 1.3887286]
+            ]
+        );
+        let g = r.exp().mean().backward();
+        assert_close_to_literal!(
+            g.get(&a),
+            [
+                [0.116593644, 0.053605597, 0.046637744],
+                [0.019706108, -0.011002079, 0.0007670224]
+            ]
+        );
+    }
+
+    #[test]
+    fn test_2d_normalize_rms_axis_first() {
+        let dev: TestDevice = Default::default();
+        let a = dev
+            .tensor([[-2.0, 0.0], [1.0, 2.0], [4.0, 5.0]])
+            .to_dtype::<TestDtype>();
+        let r = a.leaky_trace().normalize_rms::<Axis<0>>(1e-5);
+        assert_close_to_literal!(
+            r,
+            [
+                [-0.7559284, 0.0],
+                [0.3779642, 0.64326715],
+                [1.5118568, 1.6081679]
+            ]
+        );
+        let g = r.exp().mean().backward();
+        assert_close_to_literal!(
+            g.get(&a),
+            [
+                [0.14153406, 0.053605597],
+                [0.03595103, -0.0043795705],
+                [0.061779693, 0.0017521679]
+            ]
+        );
+    }
+
+    #[test]
+    fn test_3d_normalize_rms_axis_last() {
+        let dev: TestDevice = Default::default();
+        let a: Tensor<Rank3<4, 2, 3>, TestDtype, _> = dev.ones();
+        let r = a.leaky_trace().normalize_rms::<Axis<2>>(1e-5);
+        assert_close_to_literal!(r, [[[1.0; 3]; 2]; 4], 1e-5);
+        let g = r.exp().mean().backward();
+        assert_close_to_literal!(g.get(&a), [[[0.0; 3]; 2]; 4], 1e-5);
+    }
+}
+
+// Implementation references:
+// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
+// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222
diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 8cbc2137..91f87cf6 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -114,25 +114,49 @@ pub trait Device<E: Dtype>:
     + crate::tensor_ops::axpy::AxpyKernel<E>
 
     // conv1d
-    + super::super::conv1d::Conv1DKernel<E>
+    + NonCudnnCuda<E>
+{
+}
+
+#[cfg(feature = "cudnn")]
+pub trait NonCudnnCuda<E: Dtype> {}
+
+#[cfg(not(feature = "cudnn"))]
+pub trait NonCudnnCuda<E: Dtype>:
+    // conv1d
+    super::super::conv1d::Conv1DKernel<E>
 {
 }
 
 #[cfg(feature = "f16")]
-impl Device<crate::dtypes::f16> for crate::tensor::Cpu {}
-#[cfg(feature = "f16")]
-impl Device<crate::dtypes::AMP<crate::dtypes::f16>> for crate::tensor::Cpu {}
+mod f16_ {
+    use super::*;
+    impl Device<crate::dtypes::f16> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<crate::dtypes::f16> for crate::tensor::Cpu {}
+    impl Device<crate::dtypes::AMP<crate::dtypes::f16>> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<crate::dtypes::AMP<crate::dtypes::f16>> for crate::tensor::Cpu {}
+}
 impl Device<f32> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
 impl Device<f64> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f64> for crate::tensor::Cpu {}
 
 #[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<crate::dtypes::f16> for crate::tensor::Cuda {}
-#[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<crate::dtypes::AMP<crate::dtypes::f16>> for crate::tensor::Cuda {}
-#[cfg(feature = "cuda")]
-impl Device<f32> for crate::tensor::Cuda {}
+mod cuda_f16 {
+    use super::*;
+    impl Device<crate::dtypes::f16> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<crate::dtypes::f16> for crate::tensor::Cuda {}
+    impl Device<crate::dtypes::AMP<crate::dtypes::f16>> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<crate::dtypes::AMP<crate::dtypes::f16>> for crate::tensor::Cuda {}
+}
 #[cfg(feature = "cuda")]
-impl Device<f64> for crate::tensor::Cuda {}
+mod cuda {
+    use super::*;
+    impl Device<f32> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
+    impl Device<f64> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
+}
 
 // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
@@ -140,7 +164,11 @@ impl Device<f64> for crate::tensor::Cuda {}
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
 // impl Device<crate::dtypes::AMP<crate::dtypes::f16>> for crate::tensor::Webgpu {}
 #[cfg(feature = "webgpu")]
-impl Device<f32> for crate::tensor::Webgpu {}
+mod webgpu {
+    use super::*;
+    impl Device<f32> for crate::tensor::Webgpu {}
+    impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
+}
 
 // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
 // #[cfg(feature = "webgpu")]
diff --git a/dfdx/examples/12-mnist.rs b/dfdx/examples/12-mnist.rs
index 705d14c8..00d43452 100644
--- a/dfdx/examples/12-mnist.rs
+++ b/dfdx/examples/12-mnist.rs
@@ -62,9 +62,6 @@ type Mlp = (
 const BATCH_SIZE: usize = 32;
 
 fn main() {
-    // ftz substantially improves performance
-    dfdx::flush_denormals_to_zero();
-
     let mnist_path = std::env::args()
         .nth(1)
         .unwrap_or_else(|| "./datasets/MNIST/raw".to_string());
diff --git a/dfdx/src/nn/layers/layer_rms_norm1d.rs b/dfdx/src/nn/layers/layer_rms_norm1d.rs
new file mode 100644
index 00000000..a62fffb9
--- /dev/null
+++ b/dfdx/src/nn/layers/layer_rms_norm1d.rs
@@ -0,0 +1,169 @@
+use crate::prelude::*;
+
+/// Implements RMS layer normalization as described in [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467).
+///
+/// This calls [normalize_rms()] on the last axis of the input to normalize to a unit RMS value, and then does an element-wise
+/// affine transform using learnable parameters.
+///
+/// Epsilon is passed to [normalize_rms()] and added to the mean of the squared values for numerical stability. It defaults to `1e-5`.
+///
+/// Generics:
+/// - `M` The size of the affine transform tensors.
+///
+/// # Examples
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # use dfdx::*;
+/// # let dev: Cpu = Default::default();
+/// type Model = LayerRMSNorm1DConstConfig<5>;
+/// let model = dev.build_module::<f32>(Model::default());
+/// let _: Tensor<Rank1<5>, f32, _> = model.forward(dev.zeros::<Rank1<5>>());
+/// ```
+#[derive(Default, Clone, Copy, Debug)]
+#[repr(transparent)]
+pub struct LayerRMSNorm1DConfig<M: Dim>(pub M);
+
+/// Compile time sugar alias around [LayerRMSNorm1DConfig]
+pub type LayerRMSNorm1DConstConfig<const M: usize> = LayerRMSNorm1DConfig<Const<M>>;
+
+impl<M: Dim, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for LayerRMSNorm1DConfig<M> {
+    type Built = LayerRMSNorm1D<M, E, D>;
+    fn try_build_on_device(&self, device: &D) -> Result<Self::Built, crate::tensor::Error> {
+        Ok(LayerRMSNorm1D {
+            gamma: device.try_ones_like(&(self.0,))?,
+            beta: device.try_zeros_like(&(self.0,))?,
+            epsilon: 1e-5,
+        })
+    }
+}
+
+/// See [LayerRMSNorm1DConfig]
+#[derive(Clone, Debug, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
+pub struct LayerRMSNorm1D<M: Dim, Elem: Dtype, Dev: Device<Elem>> {
+    #[param]
+    #[cfg_attr(feature = "safetensors", serialize)]
+    pub gamma: Tensor<(M,), Elem, Dev>,
+    #[param]
+    #[cfg_attr(feature = "safetensors", serialize)]
+    pub beta: Tensor<(M,), Elem, Dev>,
+    #[cfg_attr(feature = "safetensors", serialize)]
+    pub epsilon: f64,
+}
+
+impl<M: Dim, E: Dtype, D: Device<E>> ResetParams<E, D> for LayerRMSNorm1D<M, E, D> {
+    fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error> {
+        self.gamma.try_fill_with_ones()?;
+        self.beta.try_fill_with_zeros()?;
+        Ok(())
+    }
+}
+
+impl<M: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(M,), E, D, T>>
+    for LayerRMSNorm1D<M, E, D>
+{
+    type Output = Tensor<(M,), E, D, T>;
+    fn try_forward(&self, x: Tensor<(M,), E, D, T>) -> Result<Self::Output, Error> {
+        let x = x.try_normalize_rms::<Axis<0>>(self.epsilon)?;
+        let x = self.gamma.retaped::<T>().try_mul(x)?;
+        self.beta.retaped::<T>().try_add(x)
+    }
+}
+
+impl<Batch: Dim, M: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(Batch, M), E, D, T>>
+    for LayerRMSNorm1D<M, E, D>
+{
+    type Output = Tensor<(Batch, M), E, D, T>;
+    fn try_forward(&self, x: Tensor<(Batch, M), E, D, T>) -> Result<Self::Output, Error> {
+        let x = x.try_normalize_rms::<Axis<1>>(self.epsilon)?;
+        let x = self.gamma.retaped::<T>().broadcast_like(&x).try_mul(x)?;
+        self.beta.retaped::<T>().broadcast_like(&x).try_add(x)
+    }
+}
+
+impl<Batch: Dim, Seq: Dim, M: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>>
+    Module<Tensor<(Batch, Seq, M), E, D, T>> for LayerRMSNorm1D<M, E, D>
+{
+    type Output = Tensor<(Batch, Seq, M), E, D, T>;
+    fn try_forward(&self, x: Tensor<(Batch, Seq, M), E, D, T>) -> Result<Self::Output, Error> {
+        let x = x.try_normalize_rms::<Axis<2>>(self.epsilon)?;
+        let x = self.gamma.retaped::<T>().broadcast_like(&x).try_mul(x)?;
+        self.beta.retaped::<T>().broadcast_like(&x).try_add(x)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tests::*;
+
+    #[test]
+    fn test_layer_rms_norm_reset() {
+        let dev: TestDevice = Default::default();
+
+        let mut m = dev.build_module::<TestDtype>(<LayerRMSNorm1DConstConfig<5>>::default());
+        assert_close_to_literal!(m.gamma, [1.0; 5]);
+        assert_close_to_literal!(m.beta, [0.0; 5]);
+
+        m.gamma = dev.sample_normal();
+        m.beta = dev.sample_normal();
+
+        assert_ne!(m.gamma.array(), [TestDtype::ONE; 5]);
+        assert_ne!(m.beta.array(), [TestDtype::default(); 5]);
+
+        m.reset_params();
+
+        assert_close_to_literal!(m.gamma, [1.0; 5]);
+        assert_close_to_literal!(m.beta, [0.0; 5]);
+    }
+
+    #[test]
+    fn test_layer_rms_norm_1d_forward() {
+        let dev: TestDevice = Default::default();
+        let mut m = dev.build_module::<TestDtype>(<LayerRMSNorm1DConstConfig<5>>::default());
+        let x = dev.sample_normal::<Rank1<5>>();
+        let r = m.forward_mut(x.leaky_trace());
+        assert_close_to_literal!(
+            r,
+            [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052]
+        );
+        let g = r.mean().backward();
+        assert_close_to_literal!(
+            g.get(&m.gamma),
+            [0.10726271, 0.12916003, -0.3666012, 0.024579724, -0.19186105]
+        );
+        assert_close_to_literal!(g.get(&m.beta), [0.2; 5]);
+    }
+
+    #[test]
+    fn test_layer_rms_norm_2d_forward() {
+        let dev: TestDevice = Default::default();
+        let m = dev.build_module::<TestDtype>(<LayerRMSNorm1DConstConfig<5>>::default());
+        let x = dev.sample_normal::<Rank2<3, 5>>();
+        let r = m.forward(x.leaky_trace());
+        assert_close_to_literal!(
+            r,
+            [
+                [0.53631353, 0.6458002, -1.8330059, 0.12289862, -0.9593052],
+                [1.0418473, -1.199064, 0.49583954, 0.5000605, 1.4074267],
+                [0.90727454, -1.6644237, -0.5176145, 1.0127299, -0.33612955]
+            ]
+        );
+        let g = r.mean().backward();
+        assert_close_to_literal!(
+            g.get(&m.gamma),
+            [
+                0.16569571,
+                -0.14784585,
+                -0.123652056,
+                0.10904594,
+                0.0074661337
+            ]
+        );
+        assert_close_to_literal!(g.get(&m.beta), [0.2; 5]);
+    }
+}
+
+// Implementation references:
+// - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
+// - https://github.com/kroggen/mamba.c/blob/7387f49e352f86a0c22041c0f66fd2a40b58a207/mamba.c#L222
diff --git a/dfdx/src/nn/layers/mod.rs b/dfdx/src/nn/layers/mod.rs
index 828b1e97..062b9f08 100644
--- a/dfdx/src/nn/layers/mod.rs
+++ b/dfdx/src/nn/layers/mod.rs
@@ -20,6 +20,7 @@ mod gelu;
 mod generalized_add;
 mod generalized_mul;
 mod layer_norm1d;
+mod layer_rms_norm1d;
 mod leaky_relu;
 mod linear;
 mod ln;
@@ -73,6 +74,7 @@ pub use gelu::{AccurateGeLU, FastGeLU};
 pub use generalized_add::GeneralizedAdd;
 pub use generalized_mul::GeneralizedMul;
 pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig};
+pub use layer_rms_norm1d::{LayerRMSNorm1D, LayerRMSNorm1DConfig, LayerRMSNorm1DConstConfig};
 pub use leaky_relu::LeakyReLU;
 pub use linear::{Linear, LinearConfig, LinearConstConfig};
 pub use ln::Ln;
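
Reviewer note, not part of the patch: the new op computes `t / (t.square().mean() + epsilon).sqrt()` along the reduced axis, and `LayerRMSNorm1D` then applies the learnable `gamma`/`beta` affine transform. A minimal plain-Rust sketch of that math over a single row, with a hypothetical `rms_norm` helper (names and signature made up for illustration), for comparison against the mamba.c reference linked above:

```rust
/// Illustrative only: RMS layer norm over one row, mirroring
/// `t / sqrt(mean(t^2) + eps)` followed by `gamma * t + beta`.
fn rms_norm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), gamma.len());
    assert_eq!(x.len(), beta.len());
    // mean of squares along the normalized axis
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    // reciprocal of the (stabilized) root mean square
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(gamma.iter().zip(beta.iter()))
        .map(|(v, (g, b))| g * (v * inv_rms) + b)
        .collect()
}

fn main() {
    // With gamma = 1 and beta = 0 this matches the values asserted in
    // `test_1d_normalize_rms_axis_last` above.
    let y = rms_norm(&[-2.0, 0.0, 5.0], &[1.0; 3], &[0.0; 3], 1e-5);
    println!("{y:?}"); // approximately [-0.6433, 0.0, 1.6082]
}
```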
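Also not part of the patch: a hypothetical usage sketch showing the new layer dropped into a config tuple, following the same `build_module` pattern as the MNIST example and the doc example above. The `Block` alias and the 784/128 sizes are made up for illustration; `LinearConstConfig` and `ReLU` are the existing layers re-exported from `dfdx::nn::layers`.

```rust
use dfdx::prelude::*;

// Hypothetical model config: a linear layer followed by RMS norm and ReLU.
type Block = (
    LinearConstConfig<784, 128>,
    LayerRMSNorm1DConstConfig<128>,
    ReLU,
);

fn main() {
    let dev: Cpu = Default::default();
    let model = dev.build_module::<f32>(Block::default());
    let x: Tensor<Rank1<784>, f32, _> = dev.sample_normal();
    // The (M,) Module impl added above handles this un-batched forward pass;
    // batched inputs go through the (Batch, M) and (Batch, Seq, M) impls.
    let _y: Tensor<Rank1<128>, f32, _> = model.forward(x);
}
```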