diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs index 20c55da2..8ba56589 100644 --- a/dfdx-core/src/nn_traits/mod.rs +++ b/dfdx-core/src/nn_traits/mod.rs @@ -113,6 +113,135 @@ pub trait ZeroGrads> { } } +/// Something that can view or mutate a [Gradients] object. +pub trait WithGrads> { + /// View the gradient values for each parameter. + fn grads_element_view(&self, grads: &Gradients, f: F) { + self.try_grads_element_view(grads, f).unwrap() + } + /// View the gradient values for each parameter. + fn try_grads_element_view( + &self, + grads: &Gradients, + f: F, + ) -> Result<(), Error>; + /// View the gradient values for each tensor (unique id). + fn grads_view(&self, grads: &Gradients, f: F) { + self.try_grads_view(grads, f).unwrap() + } + /// View the gradient values for each tensor (unique id). + fn try_grads_view(&self, grads: &Gradients, f: F) -> Result<(), Error>; + /// Mutate the gradient values for each parameter. + fn grads_element_map E>(&self, grads: &mut Gradients, f: F) { + self.try_grads_element_map(grads, f).unwrap() + } + /// Mutate the gradient values for each parameter. + fn try_grads_element_map E>( + &self, + grads: &mut Gradients, + f: F, + ) -> Result<(), crate::tensor::Error>; + /// Mutate the gradient values for each tensor (unique id). + fn grads_map) -> Option>>(&self, grads: &mut Gradients, f: F) { + self.try_grads_map(grads, f).unwrap() + } + /// Mutate the gradient values for each tensor (unique id). + fn try_grads_map) -> Option>>( + &self, + grads: &mut Gradients, + f: F, + ) -> Result<(), crate::tensor::Error>; + /// Changes the gradient values for each parameter to be between `min` and `max`. + /// + /// Note that this may change the "direction" of your gradients. + fn grads_clamp(&self, grads: &mut Gradients, min: E, max: E) + where + E: std::cmp::PartialOrd + Clone, + { + self.try_grads_clamp(grads, min, max).unwrap() + } + /// Changes the gradient values for each parameter to be between `min` and `max`. + /// + /// Note that this may change the "direction" of your gradients. + fn try_grads_clamp(&self, grads: &mut Gradients, min: E, max: E) -> Result<(), Error> + where + E: std::cmp::PartialOrd + Clone, + { + self.try_grads_element_map(grads, |e| { + if e < min { + min + } else if e > max { + max + } else { + e + } + }) + } + /// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`. + /// + /// Note that this may change the "direction" of your gradients. + fn grads_clip_value(&self, grads: &mut Gradients, threshold: E) + where + E: std::cmp::PartialOrd + std::ops::Neg + Clone, + { + self.try_grads_clip_value(grads, threshold).unwrap() + } + /// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`. + /// + /// Note that this may change the "direction" of your gradients. + fn try_grads_clip_value(&self, grads: &mut Gradients, threshold: E) -> Result<(), Error> + where + E: std::cmp::PartialOrd + std::ops::Neg + Clone, + { + self.try_grads_clamp(grads, -threshold, threshold) + } + /// Accumulates into `acc` the squared value for the gradients. + /// + /// After the accumulation, taking the sqrt of `acc` results in the gradients norm. + fn grads_norm_squared(&self, grads: &Gradients, acc: &mut E) + where + E: num_traits::Zero + std::ops::Mul + num_traits::Float, + { + self.try_grads_norm_squared(grads, acc).unwrap() + } + /// Accumulates into `acc` the squared value for the gradients. + /// + /// After the accumulation, taking the sqrt of `acc` results in the gradients norm. + fn try_grads_norm_squared(&self, grads: &Gradients, acc: &mut E) -> Result<(), Error> + where + E: std::ops::Mul + num_traits::Float, + { + self.try_grads_element_view(grads, |e| *acc += *e * *e) + } + /// Given a `norm` for all of the gradient values, scales down all gradients so their norm is not higher than `norm_threshold`. + /// + /// Note that this doesn't change the "direction" of your gradients. + fn grads_clip_norm(&self, grads: &mut Gradients, norm: E, norm_threshold: E) + where + E: Clone + std::cmp::PartialOrd + std::ops::Mul + std::ops::Div, + { + self.try_grads_clip_norm(grads, norm, norm_threshold) + .unwrap() + } + /// Given a `norm` for all of the gradient values, scales down all gradients so their norm is not higher than `norm_threshold`. + /// + /// Note that this doesn't change the "direction" of your gradients. + fn try_grads_clip_norm( + &self, + grads: &mut Gradients, + norm: E, + norm_threshold: E, + ) -> Result<(), Error> + where + E: Clone + std::cmp::PartialOrd + std::ops::Mul + std::ops::Div, + { + if norm > norm_threshold { + self.try_grads_element_map(grads, |e| norm_threshold * e / norm)? + } + Ok(()) + } +} + #[cfg(feature = "safetensors")] /// Something that can be saved to a .safetensors file. pub trait SaveSafeTensors { diff --git a/dfdx-core/src/nn_traits/tuples.rs b/dfdx-core/src/nn_traits/tuples.rs index 97e8c7de..74cdf405 100644 --- a/dfdx-core/src/nn_traits/tuples.rs +++ b/dfdx-core/src/nn_traits/tuples.rs @@ -67,6 +67,25 @@ macro_rules! tuple_impls { } } + impl, Elem: Dtype, $($name: crate::nn_traits::WithGrads),+> crate::nn_traits::WithGrads for ($($name,)+) { + fn try_grads_element_view(&self, grads: &crate::prelude::Gradients, mut f: F) -> Result<(), Error> { + $(self.$idx.try_grads_element_view(grads, &mut f)?;)+ + Ok(()) + } + fn try_grads_view(&self, grads: &crate::prelude::Gradients, mut f: F) -> Result<(), Error> { + $(self.$idx.try_grads_view(grads, &mut f)?;)+ + Ok(()) + } + fn try_grads_element_map Elem>(&self, grads: &mut crate::prelude::Gradients, mut f: F) -> Result<(), Error> { + $(self.$idx.try_grads_element_map(grads, &mut f)?;)+ + Ok(()) + } + fn try_grads_map) -> Option>>(&self, grads: &mut crate::prelude::Gradients, mut f: F) -> Result<(), Error> { + $(self.$idx.try_grads_map(grads, &mut f)?;)+ + Ok(()) + } + } + /*This macro expands like this for a 4-tuple: impl< diff --git a/dfdx-core/src/nn_traits/vecs.rs b/dfdx-core/src/nn_traits/vecs.rs index 803a07d8..983b4613 100644 --- a/dfdx-core/src/nn_traits/vecs.rs +++ b/dfdx-core/src/nn_traits/vecs.rs @@ -58,6 +58,51 @@ impl, T: crate::nn_traits::ZeroGrads> crate::nn_tra } } +impl, T: crate::nn_traits::WithGrads> crate::nn_traits::WithGrads + for Vec +{ + fn try_grads_element_view( + &self, + grads: &crate::tensor::Gradients, + mut f: F, + ) -> Result<(), crate::tensor::Error> { + for m_i in self.iter() { + m_i.try_grads_element_view(grads, &mut f)?; + } + Ok(()) + } + fn try_grads_view( + &self, + grads: &crate::tensor::Gradients, + mut f: F, + ) -> Result<(), crate::tensor::Error> { + for m_i in self.iter() { + m_i.try_grads_view(grads, &mut f)?; + } + Ok(()) + } + fn try_grads_element_map E>( + &self, + grads: &mut crate::tensor::Gradients, + mut f: F, + ) -> Result<(), crate::tensor::Error> { + for m_i in self.iter() { + m_i.try_grads_element_map(grads, &mut f)?; + } + Ok(()) + } + fn try_grads_map) -> Option>>( + &self, + grads: &mut crate::tensor::Gradients, + mut f: F, + ) -> Result<(), crate::tensor::Error> { + for m_i in self.iter() { + m_i.try_grads_map(grads, &mut f)?; + } + Ok(()) + } +} + #[cfg(feature = "safetensors")] impl crate::nn_traits::SaveSafeTensors for Vec { fn write_safetensors( diff --git a/dfdx-core/src/tensor/cpu/allocate.rs b/dfdx-core/src/tensor/cpu/allocate.rs index cf93623b..e85ca7ba 100644 --- a/dfdx-core/src/tensor/cpu/allocate.rs +++ b/dfdx-core/src/tensor/cpu/allocate.rs @@ -78,6 +78,48 @@ impl ZeroFillStorage for Cpu { } } +impl WithStorage for Cpu { + /// View the values by each element (in-place). + fn try_element_view(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> { + for e in storage.iter() { + f(e); + } + Ok(()) + } + /// View the values by a [Vec] (in-place). + fn try_view(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> { + f(storage.data.as_slice()); + Ok(()) + } + /// Mutates the values by each element (in-place). + fn try_element_map E>( + &self, + storage: &mut Self::Vec, + mut f: F, + ) -> Result<(), Error> { + for e in storage.iter_mut() { + let fe = f(*e); + *e = fe; + } + Ok(()) + } + /// Mutates a clone of the values (not in-place). + /// + /// If `Some` is returned, replaces the changed values back into the object. + /// Otherwise if `None` is returned, the changed values are discarded and the object stays intact. + fn try_map) -> Option>>( + &self, + storage: &mut Self::Vec, + mut f: F, + ) -> Result<(), Error> { + let storage_copy = storage.data.clone(); + if let Some(fstorage) = f(storage_copy) { + storage.data.copy_from_slice(&fstorage); + } + Ok(()) + } +} + impl OnesTensor for Cpu { fn try_ones_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); diff --git a/dfdx-core/src/tensor/cuda/allocate.rs b/dfdx-core/src/tensor/cuda/allocate.rs index aa489f9a..645de411 100644 --- a/dfdx-core/src/tensor/cuda/allocate.rs +++ b/dfdx-core/src/tensor/cuda/allocate.rs @@ -60,6 +60,53 @@ impl ZeroFillStorage for Cuda { } } +impl WithStorage for Cuda { + /// View a copy of the values by each element (not in-place). + fn try_element_view(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> { + let v = self.dev.dtoh_sync_copy(storage)?; + for e in v.iter() { + f(e); + } + Ok(()) + } + /// View a copy of the values by a [Vec] (not in-place). + fn try_view(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> { + let v = self.dev.dtoh_sync_copy(storage)?; + f(v.as_slice()); + Ok(()) + } + /// Mutates a copy of the values by each element (not in-place). + /// Then the values in Cuda memory are replaced by the changed values. + fn try_element_map E>( + &self, + storage: &mut Self::Vec, + mut f: F, + ) -> Result<(), Error> { + let mut v = self.dev.dtoh_sync_copy(storage)?; + for e in v.iter_mut() { + let fe = (&mut f)(*e); + *e = fe; + } + self.dev.htod_copy_into(v, storage)?; + Ok(()) + } + /// Mutates a copy of the values (not in-place). + /// + /// If `Some` is returned, the values in Cuda memory are replaced by the changed values. + /// Otherwise if `None` is returned, the values in Cuda memory are left intact. + fn try_map) -> Option>>( + &self, + storage: &mut Self::Vec, + mut f: F, + ) -> Result<(), Error> { + let v = self.dev.dtoh_sync_copy(storage)?; + if let Some(fv) = (&mut f)(v) { + self.dev.htod_copy_into(fv, storage)?; + } + Ok(()) + } +} + impl OnesTensor for Cuda where Cpu: OnesTensor, diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs index d24e2e32..e87f2bdf 100644 --- a/dfdx-core/src/tensor/gradients.rs +++ b/dfdx-core/src/tensor/gradients.rs @@ -86,14 +86,14 @@ impl> Gradients { /// Returns a mutable reference to the data associated with `t`. /// /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug. - pub(crate) fn get_mut(&mut self, t: &impl Tensorlike) -> &mut D::Vec { + pub fn get_mut(&mut self, t: &impl Tensorlike) -> &mut D::Vec { self.gradient_by_id.get_mut(&t.id()).unwrap() } /// Returns an immutable reference to the data associated with `t`. /// /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug. - pub(crate) fn get_ref(&mut self, t: &impl Tensorlike) -> &D::Vec { + pub fn get_ref(&self, t: &impl Tensorlike) -> &D::Vec { self.gradient_by_id.get(&t.id()).unwrap() } diff --git a/dfdx-core/src/tensor/mod.rs b/dfdx-core/src/tensor/mod.rs index acc4074a..0fb5ab34 100644 --- a/dfdx-core/src/tensor/mod.rs +++ b/dfdx-core/src/tensor/mod.rs @@ -160,7 +160,7 @@ mod tensor_impls; pub use error::Error; pub(crate) use ghost::GhostTensor; -pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage}; +pub(crate) use storage_traits::{OneFillStorage, WithStorage, ZeroFillStorage}; pub use tensorlike::Tensorlike; pub use cpu::Cpu; diff --git a/dfdx-core/src/tensor/storage_traits.rs b/dfdx-core/src/tensor/storage_traits.rs index 9578947d..35c9f735 100644 --- a/dfdx-core/src/tensor/storage_traits.rs +++ b/dfdx-core/src/tensor/storage_traits.rs @@ -170,6 +170,26 @@ pub trait ZeroFillStorage: Storage { fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error>; } +/// View or mutate a [Storage::Vec] object. +pub trait WithStorage: Storage { + /// View the values by each element. + fn try_element_view(&self, storage: &Self::Vec, f: F) -> Result<(), Error>; + /// View the values by a [Vec]. + fn try_view(&self, storage: &Self::Vec, f: F) -> Result<(), Error>; + /// Mutates the values by each element. + fn try_element_map E>(&self, storage: &mut Self::Vec, f: F) + -> Result<(), Error>; + /// Mutates a clone of the values. + /// + /// If `Some` is returned, replaces the changed values back into the object. + /// Otherwise if `None` is returned, the changed values are discarded and the object stays intact. + fn try_map) -> Option>>( + &self, + storage: &mut Self::Vec, + f: F, + ) -> Result<(), Error>; +} + /// Construct tensors filled with ones. pub trait OnesTensor: Storage { /// Creates a tensor filled with ones. @@ -567,3 +587,62 @@ impl> TensorFrom<(Vec, S), S, E> for D { self.try_tensor_from_vec(src, shape) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{prelude::SumTo, tensor::*, tensor_ops::Backward, tests::*}; + use core::ops::Mul; + + #[test] + fn test_map_grads() { + let dev: TestDevice = Default::default(); + let x1 = dev.tensor([1., 1., 1., 1., 1., 1.]).to_dtype::(); + let x2 = dev + .tensor([-3., -2., -1., 1., 2., 3.]) + .to_dtype::(); + let loss = x1.leaky_trace().mul(x2).try_sum().unwrap(); + let mut grads = loss.backward(); + let grads_x1 = grads.get_mut(&x1); + + let mut acc = 0.; + let map_element = |e| { + acc += 1.; + e + acc + }; + let map_vec = |v: Vec<_>| Some(v.into_iter().map(|e| e * 0.5).collect()); + + let (g1, g2, g3); + let r1 = vec![-3., -2., -1., 1., 2., 3.]; + let r2 = vec![-2., 0., 2., 5., 7., 9.]; + let r3 = vec![-1., 0., 1., 2.5, 3.5, 4.5]; + + #[cfg(feature = "cuda")] + { + g1 = dev.dev.dtoh_sync_copy(grads_x1).unwrap(); + dev.try_element_map(grads_x1, map_element).unwrap(); + g2 = dev.dev.dtoh_sync_copy(grads_x1).unwrap(); + dev.try_map(grads_x1, map_vec).unwrap(); + g3 = dev.dev.dtoh_sync_copy(grads_x1).unwrap(); + }; + #[cfg(feature = "webgpu")] + { + g1 = todo!(); + dev.try_element_map(grads_x1, map_element).unwrap(); + g2 = todo!(); + dev.try_map(grads_x1, map_vec).unwrap(); + g3 = todo!(); + }; + #[cfg(not(any(feature = "cuda", feature = "webgpu")))] + { + g1 = grads_x1.data.clone(); + dev.try_element_map(grads_x1, map_element).unwrap(); + g2 = grads_x1.data.clone(); + dev.try_map(grads_x1, map_vec).unwrap(); + g3 = grads_x1.data.clone(); + }; + assert_eq!(g1, r1); + assert_eq!(g2, r2); + assert_eq!(g3, r3); + } +} diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs index 4e4a2692..410c8404 100644 --- a/dfdx-core/src/tensor/webgpu/allocate.rs +++ b/dfdx-core/src/tensor/webgpu/allocate.rs @@ -71,6 +71,29 @@ impl ZeroFillStorage for Webgpu { } } +impl WithStorage for Webgpu { + fn try_element_view(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> { + todo!() + } + fn try_view(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> { + todo!() + } + fn try_element_map E>( + &self, + storage: &mut Self::Vec, + mut f: F, + ) -> Result<(), Error> { + todo!() + } + fn try_map) -> Option>>( + &self, + storage: &mut Self::Vec, + mut f: F, + ) -> Result<(), Error> { + todo!() + } +} + impl OnesTensor for Webgpu { fn try_ones_like(&self, src: &S) -> Result, Error> { let shape = *src.shape(); diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 91f87cf6..5428ed42 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -32,6 +32,7 @@ pub trait Device: + crate::tensor::SampleTensor + crate::tensor::OneFillStorage + crate::tensor::ZeroFillStorage + + crate::tensor::WithStorage // broadcast & reduces + super::super::sum_to::SumKernel diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 4eca0d82..cb19f775 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -13,8 +13,9 @@ macro_rules! has_attr { /// 2. [dfdx::nn_traits::ResetParams] /// 3. [dfdx::nn_traits::UpdateParams] /// 4. [dfdx::nn_traits::ZeroGrads] -/// 5. [dfdx::nn_traits::SaveSafeTensors] -/// 6. [dfdx::nn_traits::LoadSafeTensors] +/// 5. [dfdx::nn_traits::WithGrads] +/// 6. [dfdx::nn_traits::SaveSafeTensors] +/// 7. [dfdx::nn_traits::LoadSafeTensors] /// /// If your struct contains sub module configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx::nn_traits::BuildOnDevice]. /// @@ -175,11 +176,11 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream quote!() }; quote! { - #[derive(Clone, Debug, ::dfdx::ResetParams, ::dfdx::UpdateParams, ::dfdx::ZeroGrads, #safetensors_derive)] + #[derive(Clone, Debug, ::dfdx::ResetParams, ::dfdx::UpdateParams, ::dfdx::ZeroGrads, ::dfdx::WithGrads, #safetensors_derive)] pub struct #built_name #built_impl #built_where #fields } } else { - // there are no fields to build - we still have to derive ResetParams/UpdateParams/ZeroGrads, but since + // there are no fields to build - we still have to derive ResetParams/UpdateParams/ZeroGrads/WithGrads, but since // there aren't any fields, they will just be passthrough impls let mut build_generics = built_generics.clone(); if !has_fields_to_build { @@ -242,6 +243,21 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream Ok(()) } } + + impl #build_impl ::dfdx::nn_traits::WithGrads for #builder_name #built_ty #built_where { + fn try_grads_element_view<__F: FnMut(&Elem)>(&self, _grads: & ::dfdx::tensor::Gradients, _f: __F) -> Result<(), ::dfdx::tensor::Error> { + Ok(()) + } + fn try_grads_view<__F: FnMut(&[Elem])>(&self, _grads: & ::dfdx::tensor::Gradients, _f: __F) -> Result<(), ::dfdx::tensor::Error> { + Ok(()) + } + fn try_grads_element_map<__F: FnMut(Elem) -> Elem>(&self, _grads: &mut ::dfdx::tensor::Gradients, _f: __F) -> Result<(), ::dfdx::tensor::Error> { + Ok(()) + } + fn try_grads_map<__F: FnMut(Vec) -> Option>>(&self, _grads: &mut ::dfdx::tensor::Gradients, _f: __F) -> Result<(), ::dfdx::tensor::Error> { + Ok(()) + } + } } }; (built_name, def) @@ -431,7 +447,7 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream { }; quote! { - #[derive(Clone, Debug, ::dfdx::ResetParams, ::dfdx::UpdateParams, ::dfdx::ZeroGrads, #safetensors_derive)] + #[derive(Clone, Debug, ::dfdx::ResetParams, ::dfdx::UpdateParams, ::dfdx::ZeroGrads, ::dfdx::WithGrads, #safetensors_derive)] pub struct #built_name #built_impl #built_where { #fields } @@ -831,6 +847,113 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { }) } +#[proc_macro_derive(WithGrads, attributes(param, module))] +pub fn with_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let mut input = parse_macro_input!(input as DeriveInput); + + let name = input.ident; + + let mut custom_generics = input.generics.clone(); + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Elem"), + ) { + custom_generics + .params + .push(parse_quote!(Elem: ::dfdx::prelude::Dtype)); + } + + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Dev"), + ) { + custom_generics + .params + .push(parse_quote!(Dev: ::dfdx::prelude::Device)); + } + + let where_clause = input.generics.make_where_clause(); + let mut grads_element_view = proc_macro2::TokenStream::default(); + let mut grads_view = proc_macro2::TokenStream::default(); + let mut grads_element_map = proc_macro2::TokenStream::default(); + let mut grads_map = proc_macro2::TokenStream::default(); + match &input.data { + Data::Struct(ref obj) => { + match obj.fields { + Fields::Named(ref fields) => { + for f in fields.named.iter() { + let name = &f.ident; + let ty = &f.ty; + if has_attr!(f, "module") { + where_clause + .predicates + .push(parse_quote!(#ty: ::dfdx::nn_traits::WithGrads)); + grads_element_view.extend(quote_spanned!(f.span()=>self.#name.try_grads_element_view(grads, &mut f)?;)); + grads_view.extend( + quote_spanned!(f.span()=>self.#name.try_grads_view(grads, &mut f)?;), + ); + grads_element_map .extend( quote_spanned!(f.span()=>self.#name.try_grads_element_map(grads, &mut f)?;)); + grads_map.extend( + quote_spanned!(f.span()=>self.#name.try_grads_map(grads, &mut f)?;), + ); + } else if has_attr!(f, "param") { + grads_element_view .extend( quote_spanned!(f.span()=>self.#name.device().try_element_view(grads.get_ref(&self.#name), &mut f)?;)); + grads_view .extend( quote_spanned!(f.span()=>self.#name.device().try_view(grads.get_ref(&self.#name), &mut f)?;)); + grads_element_map .extend( quote_spanned!(f.span()=>self.#name.device().try_element_map(grads.get_mut(&self.#name), &mut f)?;)); + grads_map .extend( quote_spanned!(f.span()=>self.#name.device().try_map(grads.get_mut(&self.#name), &mut f)?;)); + } + } + } + Fields::Unnamed(ref fields) => { + for (i, f) in fields.unnamed.iter().enumerate() { + let index = Index::from(i); + let ty = &f.ty; + if has_attr!(f, "module") { + where_clause + .predicates + .push(parse_quote!(#ty: ::dfdx::nn_traits::WithGrads)); + grads_element_view.extend(quote_spanned!(f.span()=>self.#index.try_grads_element_view(grads, &mut f)?;)); + grads_view.extend(quote_spanned!(f.span()=>self.#index.try_grads_view(grads, &mut f)?;)); + grads_element_map.extend(quote_spanned!(f.span()=>self.#index.try_grads_element_map(grads, &mut f)?;)); + grads_map.extend(quote_spanned!(f.span()=>self.#index.try_grads_map(grads, &mut f)?;)); + } else if has_attr!(f, "param") { + grads_element_view.extend(quote_spanned!(f.span()=>self.#index.device().try_element_view(grads.get_ref(&self.#index), &mut f)?)); + grads_view.extend(quote_spanned!(f.span()=>self.#index.device().try_view(grads.get_ref(&self.#index), &mut f)?)); + grads_element_map.extend(quote_spanned!(f.span()=>self.#index.device().try_element_map(grads.get_mut(&self.#index), &mut f)?)); + grads_map.extend(quote_spanned!(f.span()=>self.#index.device().try_map(grads.get_mut(&self.#index), &mut f)?)); + } + } + } + Fields::Unit => {} + } + } + Data::Enum(_) => unimplemented!("WithGrads not implemented for enums."), + Data::Union(_) => unimplemented!("WithGrads not implemented for unions."), + }; + + let (impl_generics, _, _) = custom_generics.split_for_impl(); + let (_, ty_generics, where_clause) = input.generics.split_for_impl(); + + proc_macro::TokenStream::from(quote! { + impl #impl_generics ::dfdx::nn_traits::WithGrads for #name #ty_generics #where_clause { + fn try_grads_element_view<__F: FnMut(&Elem)>(&self, grads: & ::dfdx::prelude::Gradients, mut f: __F) -> Result<(), ::dfdx::tensor::Error> { + #grads_element_view + Ok(()) + } + fn try_grads_view<__F: FnMut(&[Elem])>(&self, grads: & ::dfdx::prelude::Gradients, mut f: __F) -> Result<(), ::dfdx::tensor::Error> { + #grads_view + Ok(()) + } + fn try_grads_element_map<__F: FnMut(Elem) -> Elem>(&self, grads: &mut ::dfdx::prelude::Gradients, mut f: __F) -> Result<(), ::dfdx::tensor::Error> { + #grads_element_map + Ok(()) + } + fn try_grads_map<__F: FnMut(Vec) -> Option>>(&self, grads: &mut ::dfdx::prelude::Gradients, mut f: __F) -> Result<(), ::dfdx::tensor::Error> { + #grads_map + Ok(()) + } + } + }) +} + #[proc_macro_derive(SaveSafeTensors, attributes(serialize))] pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut input = parse_macro_input!(input as DeriveInput); diff --git a/dfdx/src/lib.rs b/dfdx/src/lib.rs index 235ca50f..6f59f098 100644 --- a/dfdx/src/lib.rs +++ b/dfdx/src/lib.rs @@ -262,7 +262,7 @@ pub use dfdx_core::*; #[cfg(feature = "safetensors")] pub use safetensors; -pub use dfdx_derives::{CustomModule, ResetParams, Sequential, UpdateParams, ZeroGrads}; +pub use dfdx_derives::{CustomModule, ResetParams, Sequential, UpdateParams, WithGrads, ZeroGrads}; #[cfg(feature = "safetensors")] pub use dfdx_derives::{LoadSafeTensors, SaveSafeTensors}; diff --git a/dfdx/src/nn/layers/add_into.rs b/dfdx/src/nn/layers/add_into.rs index 5b57dd56..fe3b50bc 100644 --- a/dfdx/src/nn/layers/add_into.rs +++ b/dfdx/src/nn/layers/add_into.rs @@ -19,7 +19,7 @@ use crate::prelude::*; /// let b: Tensor, f32, _> = dev.zeros(); /// let _: Tensor, f32, _> = model.forward((a, b)); /// ``` -#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)] +#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, WithGrads, UpdateParams)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct AddInto( diff --git a/dfdx/src/nn/layers/batch_norm1d.rs b/dfdx/src/nn/layers/batch_norm1d.rs index ce8ff93c..f805a46d 100644 --- a/dfdx/src/nn/layers/batch_norm1d.rs +++ b/dfdx/src/nn/layers/batch_norm1d.rs @@ -55,7 +55,7 @@ impl> BuildOnDevice for BatchNorm1DConfig> { /// Scale for affine transform. Defaults to 1.0 diff --git a/dfdx/src/nn/layers/batch_norm2d.rs b/dfdx/src/nn/layers/batch_norm2d.rs index bd8fc013..bea0e553 100644 --- a/dfdx/src/nn/layers/batch_norm2d.rs +++ b/dfdx/src/nn/layers/batch_norm2d.rs @@ -57,7 +57,7 @@ impl> crate::nn::BuildOnDevice for BatchNor } /// See [BatchNorm2DConfig] -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct BatchNorm2D> { #[param] diff --git a/dfdx/src/nn/layers/bias1d.rs b/dfdx/src/nn/layers/bias1d.rs index c904df1a..f29c969e 100644 --- a/dfdx/src/nn/layers/bias1d.rs +++ b/dfdx/src/nn/layers/bias1d.rs @@ -36,7 +36,7 @@ impl> BuildOnDevice for Bias1DConfig { } /// See [Bias1DConfig] -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Bias1D> { #[param] diff --git a/dfdx/src/nn/layers/bias2d.rs b/dfdx/src/nn/layers/bias2d.rs index 397e0707..174e553b 100644 --- a/dfdx/src/nn/layers/bias2d.rs +++ b/dfdx/src/nn/layers/bias2d.rs @@ -36,7 +36,7 @@ impl> BuildOnDevice for Bias2DConfig { } /// See [Bias2DConfig] -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Bias2D> { #[param] diff --git a/dfdx/src/nn/layers/conv1d.rs b/dfdx/src/nn/layers/conv1d.rs index 5241b912..23f0ef4d 100644 --- a/dfdx/src/nn/layers/conv1d.rs +++ b/dfdx/src/nn/layers/conv1d.rs @@ -78,7 +78,7 @@ where } /// The module built with [Conv1DConfig]. See [Conv1DConfig] for usage. -#[derive(Debug, Clone, UpdateParams, ZeroGrads)] +#[derive(Debug, Clone, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Conv1D where diff --git a/dfdx/src/nn/layers/conv2d.rs b/dfdx/src/nn/layers/conv2d.rs index c88ea821..820894f3 100644 --- a/dfdx/src/nn/layers/conv2d.rs +++ b/dfdx/src/nn/layers/conv2d.rs @@ -99,7 +99,7 @@ where } /// The module built with [Conv2DConfig]. See [Conv2DConfig] for usage. -#[derive(Debug, Clone, UpdateParams, ZeroGrads)] +#[derive(Debug, Clone, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Conv2D where diff --git a/dfdx/src/nn/layers/conv_trans2d.rs b/dfdx/src/nn/layers/conv_trans2d.rs index b7683676..24a5f611 100644 --- a/dfdx/src/nn/layers/conv_trans2d.rs +++ b/dfdx/src/nn/layers/conv_trans2d.rs @@ -77,7 +77,7 @@ where } /// See [ConvTrans2DConfig]. -#[derive(Debug, Clone, UpdateParams, ZeroGrads)] +#[derive(Debug, Clone, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct ConvTrans2D where diff --git a/dfdx/src/nn/layers/embedding.rs b/dfdx/src/nn/layers/embedding.rs index 6c7971e9..98b843cb 100644 --- a/dfdx/src/nn/layers/embedding.rs +++ b/dfdx/src/nn/layers/embedding.rs @@ -51,7 +51,7 @@ impl> BuildOnDevice for EmbeddingCo } /// See [EmbeddingConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Embedding> { #[param] diff --git a/dfdx/src/nn/layers/generalized_add.rs b/dfdx/src/nn/layers/generalized_add.rs index 3dc4708b..08d2d8aa 100644 --- a/dfdx/src/nn/layers/generalized_add.rs +++ b/dfdx/src/nn/layers/generalized_add.rs @@ -18,7 +18,7 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [4.0, 1.0, 0.0, 2.0, 6.0]); /// ``` -#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, WithGrads, UpdateParams)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct GeneralizedAdd { #[module] diff --git a/dfdx/src/nn/layers/generalized_mul.rs b/dfdx/src/nn/layers/generalized_mul.rs index 64562024..197d5b4c 100644 --- a/dfdx/src/nn/layers/generalized_mul.rs +++ b/dfdx/src/nn/layers/generalized_mul.rs @@ -17,7 +17,7 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [0.0, 0.0, 0.0, 1.0, 8.0]); /// ``` -#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, WithGrads, UpdateParams)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct GeneralizedMul { #[module] diff --git a/dfdx/src/nn/layers/layer_norm1d.rs b/dfdx/src/nn/layers/layer_norm1d.rs index 363db245..046e79d5 100644 --- a/dfdx/src/nn/layers/layer_norm1d.rs +++ b/dfdx/src/nn/layers/layer_norm1d.rs @@ -38,7 +38,7 @@ impl> BuildOnDevice for LayerNorm1DConfig> { #[param] diff --git a/dfdx/src/nn/layers/linear.rs b/dfdx/src/nn/layers/linear.rs index 2d8f2e08..13c9f137 100644 --- a/dfdx/src/nn/layers/linear.rs +++ b/dfdx/src/nn/layers/linear.rs @@ -47,7 +47,7 @@ impl> BuildOnDevice for LinearConfi } /// See [LinearConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct Linear> { #[param] diff --git a/dfdx/src/nn/layers/matmul.rs b/dfdx/src/nn/layers/matmul.rs index a2e301e9..3f2a5f2a 100644 --- a/dfdx/src/nn/layers/matmul.rs +++ b/dfdx/src/nn/layers/matmul.rs @@ -36,7 +36,7 @@ impl> BuildOnDevice for MatMulConfi } /// See [MatMulConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct MatMul> { #[param] diff --git a/dfdx/src/nn/layers/prelu.rs b/dfdx/src/nn/layers/prelu.rs index 9eb8f508..7dbbb4b2 100644 --- a/dfdx/src/nn/layers/prelu.rs +++ b/dfdx/src/nn/layers/prelu.rs @@ -19,7 +19,7 @@ impl> BuildOnDevice for PReLUConfig { } /// See [PReLUConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct PReLU> { #[param] diff --git a/dfdx/src/nn/layers/prelu1d.rs b/dfdx/src/nn/layers/prelu1d.rs index fa0a35b9..c6857aa5 100644 --- a/dfdx/src/nn/layers/prelu1d.rs +++ b/dfdx/src/nn/layers/prelu1d.rs @@ -25,7 +25,7 @@ impl> BuildOnDevice for PReLU1DConfig { } /// See [PReLU1DConfig]. -#[derive(Clone, Debug, UpdateParams, ZeroGrads)] +#[derive(Clone, Debug, UpdateParams, ZeroGrads, WithGrads)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] pub struct PReLU1D> { #[param] diff --git a/dfdx/src/nn/layers/residual_add.rs b/dfdx/src/nn/layers/residual_add.rs index 77c7ca97..a72837b5 100644 --- a/dfdx/src/nn/layers/residual_add.rs +++ b/dfdx/src/nn/layers/residual_add.rs @@ -17,7 +17,7 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [-2.0, -1.0, 0.0, 2.0, 4.0]); /// ``` -#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, WithGrads, UpdateParams)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct ResidualAdd( diff --git a/dfdx/src/nn/layers/residual_mul.rs b/dfdx/src/nn/layers/residual_mul.rs index 7a9c9c9d..69f2f7eb 100644 --- a/dfdx/src/nn/layers/residual_mul.rs +++ b/dfdx/src/nn/layers/residual_mul.rs @@ -16,7 +16,7 @@ use crate::prelude::*; /// let y = model.forward(x); /// assert_eq!(y.array(), [0.0, 0.0, 0.0, 1.0, 4.0]); /// ``` -#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, UpdateParams)] +#[derive(Default, Clone, Debug, ResetParams, ZeroGrads, WithGrads, UpdateParams)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct ResidualMul( diff --git a/dfdx/src/nn/layers/split_into.rs b/dfdx/src/nn/layers/split_into.rs index 440cba5c..9cc77d19 100644 --- a/dfdx/src/nn/layers/split_into.rs +++ b/dfdx/src/nn/layers/split_into.rs @@ -21,7 +21,7 @@ use crate::prelude::*; /// let model = dev.build_module::(Model::default()); /// let _: (Tensor, f32, _>, Tensor, f32, _>) = model.forward(dev.zeros::>()); /// ``` -#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)] +#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, WithGrads, UpdateParams)] #[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] #[repr(transparent)] pub struct SplitInto(