Skip to content

Commit

Permalink
Rename base_to_data to default_base_to_data
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 18, 2024
1 parent e72eb72 commit af91dba
Show file tree
Hide file tree
Showing 14 changed files with 121 additions and 50 deletions.
10 changes: 9 additions & 1 deletion examples/custom_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ impl<Mods: WrappedData> WrappedData for CustomModule<Mods> {
type Wrap<'a, T: Unit, Base: IsBasePtr> = Mods::Wrap<'a, T, Base>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> {
self.mods.wrap_in_base(base)
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
self.mods.wrap_in_base_unbound(base)
}

#[inline]
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b Self::Wrap<'a, T, Base>,
Expand Down
9 changes: 5 additions & 4 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
where
D: OnNewBuffer<'a, T, D, S>,
{
let data = device.base_to_data(base);
let data = device.default_base_to_data(base);
let mut buf = Buffer {
data,
device: Some(device),
Expand Down Expand Up @@ -260,12 +260,12 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
/// assert_eq!(buf.as_slice(), &[0, 1, 2, 3, 4]);
/// ```
#[inline]
pub fn deviceless<'b: 'a>(device: &'b D, len: usize) -> Buffer<'a, T, D, S>
pub fn deviceless<'b>(device: &'b D, len: usize) -> Buffer<'a, T, D, S>
where
D: DevicelessAble<'b, T, S>,
{
Buffer {
data: device.base_to_data(device.alloc(len, AllocFlag::None).unwrap()),
data: device.default_base_to_data_unbound(device.alloc(len, AllocFlag::None).unwrap()),
device: None,
}
}
Expand All @@ -286,7 +286,8 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
let mut base = unsafe { self.base().shallow() };
unsafe { base.set_flag(AllocFlag::None) };

let data: <D as Device>::Data<'b, T, S> = self.device().base_to_data_unbound::<T, S>(base);
let data: <D as Device>::Data<'b, T, S> =
self.device().default_base_to_data_unbound::<T, S>(base);

Buffer { data, device: None }
}
Expand Down
13 changes: 8 additions & 5 deletions src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ impl Device for () {
}

#[inline(always)]
fn base_to_data<'a, T: Unit, S: crate::Shape>(
fn default_base_to_data<'a, T: Unit, S: crate::Shape>(
&'a self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S> {
base
}

#[inline(always)]
fn base_to_data_unbound<'a, T: Unit, S: crate::Shape>(
fn default_base_to_data_unbound<'a, T: Unit, S: crate::Shape>(
&self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S> {
Expand Down Expand Up @@ -124,9 +124,12 @@ impl WrappedData for () {
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> {
base
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
base
}

Expand Down
4 changes: 3 additions & 1 deletion src/cache/borrow_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ impl BorrowCache {
// not using ::new, because this buf would get added to the cache of the device.
// not anymore ?
let buf: Buffer<T, D, S> = Buffer {
data: device.base_to_data_unbound(device.alloc::<S>(id.len, AllocFlag::BorrowedCache).unwrap()),
data: device.default_base_to_data_unbound(
device.alloc::<S>(id.len, AllocFlag::BorrowedCache).unwrap(),
),
device: None,
};

Expand Down
16 changes: 13 additions & 3 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,14 @@ pub trait Device: OnDropBuffer + Sized {

// add default impl if GAT default go stable
// FIXME: probably a better way to realize these
fn base_to_data<'a, T: Unit, S: Shape>(&'a self, base: Self::Base<T, S>) -> Self::Data<'a, T, S>;
fn base_to_data_unbound<'a, T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S>;
fn default_base_to_data<'a, T: Unit, S: Shape>(
&'a self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S>;
fn default_base_to_data_unbound<'a, T: Unit, S: Shape>(
&self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S>;
fn wrap_to_data<'a, T: Unit, S: Shape>(
&self,
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
Expand Down Expand Up @@ -115,7 +121,11 @@ macro_rules! impl_buffer_hook_traits {
Self: 'dev,
{
#[inline]
unsafe fn on_new_buffer(&'dev self, device: &'dev D, new_buf: &mut Buffer<'dev, T, D, S>) {
unsafe fn on_new_buffer(
&'dev self,
device: &'dev D,
new_buf: &mut Buffer<'dev, T, D, S>,
) {
unsafe { self.modules.on_new_buffer(device, new_buf) }
}
}
Expand Down
14 changes: 10 additions & 4 deletions src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,18 @@ impl<Mods: OnDropBuffer> Device for CPU<Mods> {
}

#[inline(always)]
fn base_to_data<'a, T: Unit, S: Shape>(&'a self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
fn default_base_to_data<'a, T: Unit, S: Shape>(
&'a self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S> {
self.wrap_in_base(base)
}

#[inline(always)]
fn base_to_data_unbound<'a, T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
fn default_base_to_data_unbound<'a, T: Unit, S: Shape>(
&self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S> {
self.wrap_in_base_unbound(base)
}

Expand All @@ -73,7 +79,7 @@ impl<Mods: OnDropBuffer> Device for CPU<Mods> {
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

// #[inline]
// fn wrap(&self) {}
}
Expand Down
7 changes: 6 additions & 1 deletion src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,10 @@ pub trait OnDropBuffer: WrappedData {

pub trait OnNewBuffer<'dev, T: Unit, D: Device, S: Shape = ()> {
#[track_caller]
unsafe fn on_new_buffer<'s>(&'dev self, _device: &'dev D, _new_buf: &'s mut Buffer<'dev, T, D, S>) {}
unsafe fn on_new_buffer<'s>(
&'dev self,
_device: &'dev D,
_new_buf: &'s mut Buffer<'dev, T, D, S>,
) {
}
}
13 changes: 9 additions & 4 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use core::{fmt::Debug, marker::PhantomData};

use crate::{
flag::AllocFlag, Autograd, Device, HasId, IsBasePtr, PtrType, ShallowCopy, Shape, ToBase, ToDim, UniqueId, Unit, WrappedData
flag::AllocFlag, Autograd, Device, HasId, IsBasePtr, PtrType, ShallowCopy, Shape, ToBase,
ToDim, UniqueId, Unit, WrappedData,
};

// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
Expand Down Expand Up @@ -42,14 +43,18 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
requires_grad: true,
data: self.modules.wrap_in_base(base),
remove_id_cb: Some(Box::new(|id| {
unsafe { &mut (*self.grads.get()).no_grads_pool }.remove(&id);
unsafe { (*self.grads.get()).buf_requires_grad.remove(&id) };
unsafe { (*self.grads.get()).no_grads_pool.remove(&id) };
})),
_pd: PhantomData,
}
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
ReqGradWrapper {
// by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change
requires_grad: true,
Expand Down
7 changes: 5 additions & 2 deletions src/modules/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ impl WrappedData for Base {
wrap
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: crate::IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: crate::IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
base
}
}
Expand Down
38 changes: 25 additions & 13 deletions src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ use core::{
};

use crate::{
AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, Downcast, ExecNow, FastCache2, Guard, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, WrappedData2, WrappedData3
AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device,
Downcast, ExecNow, FastCache2, Guard, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module,
OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule,
SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, WrappedData2,
WrappedData3,
};

#[cfg(feature = "graph")]
Expand All @@ -19,31 +23,38 @@ pub struct Cached<Mods, CacheType = FastCache2> {
cache_type: PhantomData<CacheType>,
}

impl<'w, CacheType: 'static, Mods: WrappedData2<'w> + 'static, SD: Device + 'static> WrappedData2<'w> for CachedModule<Mods, SD, CacheType> {
impl<'w, CacheType: 'static, Mods: WrappedData2<'w> + 'static, SD: Device + 'static>
WrappedData2<'w> for CachedModule<Mods, SD, CacheType>
{
type Wrap<'a, T: Unit, Base: IsBasePtr> = Guard<'a, Mods::Wrap<'a, T, Base>>;

#[inline]
fn wrap_in_base<T: Unit, Base: IsBasePtr>(&'w self, base: Base) -> Self::Wrap<'w, T, Base> {
Guard::new(CowMut::Owned(self.modules.wrap_in_base(base)))
}

#[inline]
fn wrap_in_base2<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
todo!()
// Guard::new(CowMut::Owned(self.modules.wrap_in_base(base)))
}
}

impl<CacheType: 'static, Mods: WrappedData, SD: Device> WrappedData for CachedModule<Mods, SD, CacheType> {
impl<CacheType: 'static, Mods: WrappedData, SD: Device> WrappedData
for CachedModule<Mods, SD, CacheType>
{
type Wrap<'a, T: Unit, Base: IsBasePtr> = Guard<'a, Mods::Wrap<'a, T, Base>>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> {
Guard::new(CowMut::Owned(self.modules.wrap_in_base(base)))
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
Guard::new(CowMut::Owned(self.modules.wrap_in_base_unbound(base)))
}

Expand All @@ -59,7 +70,7 @@ impl<CacheType: 'static, Mods: WrappedData, SD: Device> WrappedData for CachedMo
wrap: &'b mut Self::Wrap<'a, T, Base>,
) -> &'b mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}
}

impl<'a, CacheType, Mods: Module<'a, D>, D: Device + 'a> Module<'a, D> for Cached<Mods, CacheType>
Expand Down Expand Up @@ -161,7 +172,9 @@ where
}
}

impl<CacheType: 'static, Mods: OnDropBuffer, SD: Device> OnDropBuffer for CachedModule<Mods, SD, CacheType> {
impl<CacheType: 'static, Mods: OnDropBuffer, SD: Device> OnDropBuffer
for CachedModule<Mods, SD, CacheType>
{
#[inline]
fn on_drop_buffer<T: Unit, D: Device, S: Shape>(&self, device: &D, buf: &Buffer<T, D, S>) {
self.modules.on_drop_buffer(device, buf)
Expand All @@ -187,9 +200,7 @@ where
let entry = self.cache.get_mut(id, len)?;
let mut entry = RefMut::map(entry, |x| {
if x.is::<Mods::Wrap<'static, T, D::Base<T, S>>>() {
unsafe {
Downcast::downcast_mut_unchecked::<Mods::Wrap<'a, T, D::Base<T, S>>>(x)
}
unsafe { Downcast::downcast_mut_unchecked::<Mods::Wrap<'a, T, D::Base<T, S>>>(x) }
} else {
panic!()
}
Expand Down Expand Up @@ -482,7 +493,8 @@ impl<CacheType, Mods: OnDropBuffer, D: Device> CachedBuffers for CachedModule<Mo
}
}

impl<CacheType: 'static, Mods, D, T, S, SD> ReplaceBuf<T, D, S> for CachedModule<Mods, SD, CacheType>
impl<CacheType: 'static, Mods, D, T, S, SD> ReplaceBuf<T, D, S>
for CachedModule<Mods, SD, CacheType>
where
T: Unit,
Mods: ReplaceBuf<T, D, S>,
Expand Down
2 changes: 1 addition & 1 deletion src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ where
let base = device
.alloc::<S>(id.len, crate::flag::AllocFlag::Lazy)
.unwrap();
let data = device.base_to_data(base);
let data = device.default_base_to_data(base);
let buffer = Buffer {
data,
device: Some(device),
Expand Down
7 changes: 5 additions & 2 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ impl<T2, Mods: WrappedData> WrappedData for Lazy<'_, Mods, T2> {
_pd: PhantomData,
}
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
LazyWrapper {
maybe_data: MaybeData::Data(self.modules.wrap_in_base_unbound(base)),
_pd: PhantomData,
Expand Down
8 changes: 6 additions & 2 deletions src/modules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ pub(crate) unsafe fn register_buf_any<'a, T, D, S>(
{
// shallow copy sets flag to AllocFlag::Wrapper
let wrapped_data = unsafe { buf.base().shallow() };
let data: <D as Device>::Data<'static, T, S> = buf.device().base_to_data_unbound::<T, S>(wrapped_data);
let data: <D as Device>::Data<'static, T, S> = buf
.device()
.default_base_to_data_unbound::<T, S>(wrapped_data);

let buf: Buffer<'static, T, D, S> = Buffer { data, device: None };
cache.insert(*buf.id(), Box::new(buf));
Expand Down Expand Up @@ -130,7 +132,9 @@ pub(crate) unsafe fn register_buf_copyable<'a, T, D, S>(
{
// shallow copy sets flag to AllocFlag::Wrapper
let wrapped_data = unsafe { buf.base().shallow() };
let data: <D as Device>::Data<'static, T, S> = buf.device().base_to_data_unbound::<T, S>(wrapped_data);
let data: <D as Device>::Data<'static, T, S> = buf
.device()
.default_base_to_data_unbound::<T, S>(wrapped_data);

let buf: Buffer<'static, T, D, S> = Buffer { data, device: None };
cache.insert(*buf.id(), Box::new(buf));
Expand Down
Loading

0 comments on commit af91dba

Please sign in to comment.