From af91dbab159bda26e333e5efc5b2777beb2c9bb7 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Mon, 18 Nov 2024 01:16:50 +0100 Subject: [PATCH] Rename base_to_data to default_base_to_data --- examples/custom_module.rs | 10 ++++++++- src/buffer.rs | 9 ++++---- src/buffer/num.rs | 13 ++++++----- src/cache/borrow_cache.rs | 4 +++- src/devices.rs | 16 +++++++++++--- src/devices/cpu/cpu_device.rs | 14 ++++++++---- src/hooks.rs | 7 +++++- src/modules/autograd/wrapper.rs | 13 +++++++---- src/modules/base.rs | 7 ++++-- src/modules/cached.rs | 38 ++++++++++++++++++++++----------- src/modules/lazy.rs | 2 +- src/modules/lazy/wrapper.rs | 7 ++++-- src/modules/mod.rs | 8 +++++-- src/wrapper.rs | 23 ++++++++++++++------ 14 files changed, 121 insertions(+), 50 deletions(-) diff --git a/examples/custom_module.rs b/examples/custom_module.rs index c2156da1..f9f44ade 100644 --- a/examples/custom_module.rs +++ b/examples/custom_module.rs @@ -39,10 +39,18 @@ impl WrappedData for CustomModule { type Wrap<'a, T: Unit, Base: IsBasePtr> = Mods::Wrap<'a, T, Base>; #[inline] - fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> { self.mods.wrap_in_base(base) } + #[inline] + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base> { + self.mods.wrap_in_base_unbound(base) + } + #[inline] fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>( wrap: &'b Self::Wrap<'a, T, Base>, diff --git a/src/buffer.rs b/src/buffer.rs index b1b3cb08..08235944 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -81,7 +81,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { where D: OnNewBuffer<'a, T, D, S>, { - let data = device.base_to_data(base); + let data = device.default_base_to_data(base); let mut buf = Buffer { data, device: Some(device), @@ -260,12 +260,12 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { /// assert_eq!(buf.as_slice(), &[0, 1, 2, 3, 4]); /// ``` #[inline] - pub fn deviceless<'b: 'a>(device: &'b D, len: usize) -> Buffer<'a, T, D, S> + pub fn deviceless<'b>(device: &'b D, len: usize) -> Buffer<'a, T, D, S> where D: DevicelessAble<'b, T, S>, { Buffer { - data: device.base_to_data(device.alloc(len, AllocFlag::None).unwrap()), + data: device.default_base_to_data_unbound(device.alloc(len, AllocFlag::None).unwrap()), device: None, } } @@ -286,7 +286,8 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { let mut base = unsafe { self.base().shallow() }; unsafe { base.set_flag(AllocFlag::None) }; - let data: ::Data<'b, T, S> = self.device().base_to_data_unbound::(base); + let data: ::Data<'b, T, S> = + self.device().default_base_to_data_unbound::(base); Buffer { data, device: None } } diff --git a/src/buffer/num.rs b/src/buffer/num.rs index b003b190..f04eebf0 100644 --- a/src/buffer/num.rs +++ b/src/buffer/num.rs @@ -61,15 +61,15 @@ impl Device for () { } #[inline(always)] - fn base_to_data<'a, T: Unit, S: crate::Shape>( + fn default_base_to_data<'a, T: Unit, S: crate::Shape>( &'a self, base: Self::Base, ) -> Self::Data<'a, T, S> { base } - + #[inline(always)] - fn base_to_data_unbound<'a, T: Unit, S: crate::Shape>( + fn default_base_to_data_unbound<'a, T: Unit, S: crate::Shape>( &self, base: Self::Base, ) -> Self::Data<'a, T, S> { @@ -124,9 +124,12 @@ impl WrappedData for () { fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> { base } - + #[inline] - fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base> { base } diff --git a/src/cache/borrow_cache.rs b/src/cache/borrow_cache.rs index 0bfda0ed..49b85b8a 100644 --- a/src/cache/borrow_cache.rs +++ b/src/cache/borrow_cache.rs @@ -99,7 +99,9 @@ impl BorrowCache { // not using ::new, because this buf would get added to the cache of the device. // not anymore ? let buf: Buffer = Buffer { - data: device.base_to_data_unbound(device.alloc::(id.len, AllocFlag::BorrowedCache).unwrap()), + data: device.default_base_to_data_unbound( + device.alloc::(id.len, AllocFlag::BorrowedCache).unwrap(), + ), device: None, }; diff --git a/src/devices.rs b/src/devices.rs index db09e7c4..4fc49f72 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -58,8 +58,14 @@ pub trait Device: OnDropBuffer + Sized { // add default impl if GAT default go stable // FIXME: probably a better way to realize these - fn base_to_data<'a, T: Unit, S: Shape>(&'a self, base: Self::Base) -> Self::Data<'a, T, S>; - fn base_to_data_unbound<'a, T: Unit, S: Shape>(&self, base: Self::Base) -> Self::Data<'a, T, S>; + fn default_base_to_data<'a, T: Unit, S: Shape>( + &'a self, + base: Self::Base, + ) -> Self::Data<'a, T, S>; + fn default_base_to_data_unbound<'a, T: Unit, S: Shape>( + &self, + base: Self::Base, + ) -> Self::Data<'a, T, S>; fn wrap_to_data<'a, T: Unit, S: Shape>( &self, wrap: Self::Wrap<'a, T, Self::Base>, @@ -115,7 +121,11 @@ macro_rules! impl_buffer_hook_traits { Self: 'dev, { #[inline] - unsafe fn on_new_buffer(&'dev self, device: &'dev D, new_buf: &mut Buffer<'dev, T, D, S>) { + unsafe fn on_new_buffer( + &'dev self, + device: &'dev D, + new_buf: &mut Buffer<'dev, T, D, S>, + ) { unsafe { self.modules.on_new_buffer(device, new_buf) } } } diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index f94ce714..160d5c5f 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -43,12 +43,18 @@ impl Device for CPU { } #[inline(always)] - fn base_to_data<'a, T: Unit, S: Shape>(&'a self, base: Self::Base) -> Self::Data<'a, T, S> { + fn default_base_to_data<'a, T: Unit, S: Shape>( + &'a self, + base: Self::Base, + ) -> Self::Data<'a, T, S> { self.wrap_in_base(base) } - + #[inline(always)] - fn base_to_data_unbound<'a, T: Unit, S: Shape>(&self, base: Self::Base) -> Self::Data<'a, T, S> { + fn default_base_to_data_unbound<'a, T: Unit, S: Shape>( + &self, + base: Self::Base, + ) -> Self::Data<'a, T, S> { self.wrap_in_base_unbound(base) } @@ -73,7 +79,7 @@ impl Device for CPU { ) -> &'b mut Self::Wrap<'a, T, Self::Base> { data } - + // #[inline] // fn wrap(&self) {} } diff --git a/src/hooks.rs b/src/hooks.rs index 32d23e15..3ac442bf 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -8,5 +8,10 @@ pub trait OnDropBuffer: WrappedData { pub trait OnNewBuffer<'dev, T: Unit, D: Device, S: Shape = ()> { #[track_caller] - unsafe fn on_new_buffer<'s>(&'dev self, _device: &'dev D, _new_buf: &'s mut Buffer<'dev, T, D, S>) {} + unsafe fn on_new_buffer<'s>( + &'dev self, + _device: &'dev D, + _new_buf: &'s mut Buffer<'dev, T, D, S>, + ) { + } } diff --git a/src/modules/autograd/wrapper.rs b/src/modules/autograd/wrapper.rs index 055741df..89280bc4 100644 --- a/src/modules/autograd/wrapper.rs +++ b/src/modules/autograd/wrapper.rs @@ -1,7 +1,8 @@ use core::{fmt::Debug, marker::PhantomData}; use crate::{ - flag::AllocFlag, Autograd, Device, HasId, IsBasePtr, PtrType, ShallowCopy, Shape, ToBase, ToDim, UniqueId, Unit, WrappedData + flag::AllocFlag, Autograd, Device, HasId, IsBasePtr, PtrType, ShallowCopy, Shape, ToBase, + ToDim, UniqueId, Unit, WrappedData, }; // #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -42,14 +43,18 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> { requires_grad: true, data: self.modules.wrap_in_base(base), remove_id_cb: Some(Box::new(|id| { - unsafe { &mut (*self.grads.get()).no_grads_pool }.remove(&id); + unsafe { (*self.grads.get()).buf_requires_grad.remove(&id) }; + unsafe { (*self.grads.get()).no_grads_pool.remove(&id) }; })), _pd: PhantomData, } } - + #[inline] - fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base> { ReqGradWrapper { // by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change requires_grad: true, diff --git a/src/modules/base.rs b/src/modules/base.rs index c881ec8c..583dd27a 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -32,8 +32,11 @@ impl WrappedData for Base { wrap } - #[inline] - fn wrap_in_base_unbound<'a, T: Unit, Base: crate::IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + #[inline] + fn wrap_in_base_unbound<'a, T: Unit, Base: crate::IsBasePtr>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base> { base } } diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 425c5185..38c86e62 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -5,7 +5,11 @@ use core::{ }; use crate::{ - AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, Downcast, ExecNow, FastCache2, Guard, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, WrappedData2, WrappedData3 + AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, + Downcast, ExecNow, FastCache2, Guard, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, + OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, + SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, WrappedData2, + WrappedData3, }; #[cfg(feature = "graph")] @@ -19,14 +23,16 @@ pub struct Cached { cache_type: PhantomData, } -impl<'w, CacheType: 'static, Mods: WrappedData2<'w> + 'static, SD: Device + 'static> WrappedData2<'w> for CachedModule { +impl<'w, CacheType: 'static, Mods: WrappedData2<'w> + 'static, SD: Device + 'static> + WrappedData2<'w> for CachedModule +{ type Wrap<'a, T: Unit, Base: IsBasePtr> = Guard<'a, Mods::Wrap<'a, T, Base>>; #[inline] fn wrap_in_base(&'w self, base: Base) -> Self::Wrap<'w, T, Base> { Guard::new(CowMut::Owned(self.modules.wrap_in_base(base))) } - + #[inline] fn wrap_in_base2<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { todo!() @@ -34,16 +40,21 @@ impl<'w, CacheType: 'static, Mods: WrappedData2<'w> + 'static, SD: Device + 'sta } } -impl WrappedData for CachedModule { +impl WrappedData + for CachedModule +{ type Wrap<'a, T: Unit, Base: IsBasePtr> = Guard<'a, Mods::Wrap<'a, T, Base>>; #[inline] fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> { Guard::new(CowMut::Owned(self.modules.wrap_in_base(base))) } - - #[inline] - fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + + #[inline] + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base> { Guard::new(CowMut::Owned(self.modules.wrap_in_base_unbound(base))) } @@ -59,7 +70,7 @@ impl WrappedData for CachedMo wrap: &'b mut Self::Wrap<'a, T, Base>, ) -> &'b mut Base { Mods::wrapped_as_base_mut(wrap) - } + } } impl<'a, CacheType, Mods: Module<'a, D>, D: Device + 'a> Module<'a, D> for Cached @@ -161,7 +172,9 @@ where } } -impl OnDropBuffer for CachedModule { +impl OnDropBuffer + for CachedModule +{ #[inline] fn on_drop_buffer(&self, device: &D, buf: &Buffer) { self.modules.on_drop_buffer(device, buf) @@ -187,9 +200,7 @@ where let entry = self.cache.get_mut(id, len)?; let mut entry = RefMut::map(entry, |x| { if x.is::>>() { - unsafe { - Downcast::downcast_mut_unchecked::>>(x) - } + unsafe { Downcast::downcast_mut_unchecked::>>(x) } } else { panic!() } @@ -482,7 +493,8 @@ impl CachedBuffers for CachedModule ReplaceBuf for CachedModule +impl ReplaceBuf + for CachedModule where T: Unit, Mods: ReplaceBuf, diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 52eb74d5..de0317f0 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -350,7 +350,7 @@ where let base = device .alloc::(id.len, crate::flag::AllocFlag::Lazy) .unwrap(); - let data = device.base_to_data(base); + let data = device.default_base_to_data(base); let buffer = Buffer { data, device: Some(device), diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 969b6ded..f1fd1bb3 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -27,9 +27,12 @@ impl WrappedData for Lazy<'_, Mods, T2> { _pd: PhantomData, } } - + #[inline] - fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base> { LazyWrapper { maybe_data: MaybeData::Data(self.modules.wrap_in_base_unbound(base)), _pd: PhantomData, diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 53c9bacb..c6ad3dc9 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -99,7 +99,9 @@ pub(crate) unsafe fn register_buf_any<'a, T, D, S>( { // shallow copy sets flag to AllocFlag::Wrapper let wrapped_data = unsafe { buf.base().shallow() }; - let data: ::Data<'static, T, S> = buf.device().base_to_data_unbound::(wrapped_data); + let data: ::Data<'static, T, S> = buf + .device() + .default_base_to_data_unbound::(wrapped_data); let buf: Buffer<'static, T, D, S> = Buffer { data, device: None }; cache.insert(*buf.id(), Box::new(buf)); @@ -130,7 +132,9 @@ pub(crate) unsafe fn register_buf_copyable<'a, T, D, S>( { // shallow copy sets flag to AllocFlag::Wrapper let wrapped_data = unsafe { buf.base().shallow() }; - let data: ::Data<'static, T, S> = buf.device().base_to_data_unbound::(wrapped_data); + let data: ::Data<'static, T, S> = buf + .device() + .default_base_to_data_unbound::(wrapped_data); let buf: Buffer<'static, T, D, S> = Buffer { data, device: None }; cache.insert(*buf.id(), Box::new(buf)); diff --git a/src/wrapper.rs b/src/wrapper.rs index 63cc2fd1..05a6007d 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -1,24 +1,33 @@ use crate::{HasId, IsBasePtr, PtrType, Unit}; pub trait WrappedData2<'w> { - type Wrap<'a, T: Unit, Base: IsBasePtr>: PtrType + HasId + 'a where Self: 'w, Self: 'a; + type Wrap<'a, T: Unit, Base: IsBasePtr>: PtrType + HasId + 'a + where + Self: 'w, + Self: 'a; fn wrap_in_base(&'w self, base: Base) -> Self::Wrap<'w, T, Base>; fn wrap_in_base2<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base>; -} +} pub trait WrappedData3<'w> { - type Wrap<'a, T: Unit, Base: IsBasePtr>: PtrType + HasId + 'a where Self: 'a; + type Wrap<'a, T: Unit, Base: IsBasePtr>: PtrType + HasId + 'a + where + Self: 'a; fn wrap_in_base(&'w self, base: Base) -> Self::Wrap<'w, T, Base>; - fn wrap_in_base2<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base>; -} + fn wrap_in_base2<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) + -> Self::Wrap<'a, T, Base>; +} pub trait WrappedData { type Wrap<'a, T: Unit, Base: IsBasePtr>: PtrType + HasId + 'a; fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base>; - fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base>; + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>( + &self, + base: Base, + ) -> Self::Wrap<'a, T, Base>; #[track_caller] fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>( wrap: &'b Self::Wrap<'a, T, Base>, @@ -42,7 +51,7 @@ macro_rules! impl_wrapped_data { ) -> Self::Wrap<'a, T, Base> { self.modules.wrap_in_base(base) } - + #[inline] fn wrap_in_base_unbound<'a, T: Unit, Base: $crate::IsBasePtr>( &self,