Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct copy specialization #65

Merged
merged 16 commits into from
Apr 24, 2024
4 changes: 4 additions & 0 deletions derive/impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
alignment: struct_alignment,
has_uniform_min_alignment: true,
min_size,
is_pod: false,
extra,
}
};
Expand All @@ -614,6 +615,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
#( for<'__> #field_types_2: #root::WriteInto, )*
{
#[inline]
fn write_into<B: #root::BufferMut>(&self, writer: &mut #root::Writer<B>) {
#set_contained_rt_sized_array_length
#( #write_into_buffer_body )*
Expand All @@ -625,6 +627,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
#( for<'__> #field_types_3: #root::ReadFrom, )*
{
#[inline]
fn read_from<B: #root::BufferRef>(&mut self, reader: &mut #root::Reader<B>) {
#( #read_from_buffer_body )*
}
Expand All @@ -635,6 +638,7 @@ pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
#( for<'__> #field_types_4: #root::CreateFrom, )*
{
#[inline]
fn create_from<B: #root::BufferRef>(reader: &mut #root::Reader<B>) -> Self {
#( #create_from_buffer_body )*

Expand Down
66 changes: 66 additions & 0 deletions src/core/rw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ impl<B: BufferMut> Writer<B> {
pub fn write<const N: usize>(&mut self, val: &[u8; N]) {
self.cursor.write(val)
}

#[inline]
pub fn write_slice(&mut self, val: &[u8]) {
self.cursor.write_slice(val)
}
}

pub struct ReadContext {
Expand Down Expand Up @@ -93,6 +98,11 @@ impl<B: BufferRef> Reader<B> {
self.cursor.read()
}

#[inline]
pub fn read_slice(&mut self, val: &mut [u8]) {
self.cursor.read_slice(val)
}

#[inline]
pub fn remaining(&self) -> usize {
self.cursor.remaining()
Expand Down Expand Up @@ -130,6 +140,12 @@ impl<B: BufferRef> Cursor<B> {
self.pos += N;
res
}

#[inline]
fn read_slice(&mut self, val: &mut [u8]) {
self.buffer.read_slice(self.pos, val);
self.pos += val.len();
}
}

impl<B: BufferMut> Cursor<B> {
Expand All @@ -144,6 +160,12 @@ impl<B: BufferMut> Cursor<B> {
self.pos += N;
}

#[inline]
fn write_slice(&mut self, val: &[u8]) {
self.buffer.write_slice(self.pos, val);
self.pos += val.len();
}

#[inline]
fn try_enlarge(&mut self, wanted: usize) -> core::result::Result<(), EnlargeError> {
self.buffer.try_enlarge(wanted)
Expand All @@ -165,13 +187,17 @@ pub trait BufferRef {
fn len(&self) -> usize;

fn read<const N: usize>(&self, offset: usize) -> &[u8; N];

fn read_slice(&self, offset: usize, val: &mut [u8]);
}

pub trait BufferMut {
fn capacity(&self) -> usize;

fn write<const N: usize>(&mut self, offset: usize, val: &[u8; N]);

fn write_slice(&mut self, offset: usize, val: &[u8]);

#[inline]
fn try_enlarge(&mut self, wanted: usize) -> core::result::Result<(), EnlargeError> {
if wanted > self.capacity() {
Expand All @@ -192,6 +218,11 @@ impl BufferRef for [u8] {
use crate::utils::SliceExt;
self.array(offset)
}

#[inline]
fn read_slice(&self, offset: usize, val: &mut [u8]) {
val.copy_from_slice(&self[offset..offset + val.len()])
}
}

impl<const LEN: usize> BufferRef for [u8; LEN] {
Expand All @@ -204,6 +235,11 @@ impl<const LEN: usize> BufferRef for [u8; LEN] {
fn read<const N: usize>(&self, offset: usize) -> &[u8; N] {
<[u8] as BufferRef>::read(self, offset)
}

#[inline]
fn read_slice(&self, offset: usize, val: &mut [u8]) {
<[u8] as BufferRef>::read_slice(self, offset, val)
}
}

impl BufferRef for Vec<u8> {
Expand All @@ -216,6 +252,11 @@ impl BufferRef for Vec<u8> {
fn read<const N: usize>(&self, offset: usize) -> &[u8; N] {
<[u8] as BufferRef>::read(self, offset)
}

#[inline]
fn read_slice(&self, offset: usize, val: &mut [u8]) {
<[u8] as BufferRef>::read_slice(self, offset, val)
}
}

impl BufferMut for [u8] {
Expand All @@ -229,6 +270,11 @@ impl BufferMut for [u8] {
use crate::utils::SliceExt;
*self.array_mut(offset) = *val;
}

#[inline]
fn write_slice(&mut self, offset: usize, val: &[u8]) {
self[offset..offset + val.len()].copy_from_slice(val);
}
}

impl<const LEN: usize> BufferMut for [u8; LEN] {
Expand All @@ -241,6 +287,11 @@ impl<const LEN: usize> BufferMut for [u8; LEN] {
fn write<const N: usize>(&mut self, offset: usize, val: &[u8; N]) {
<[u8] as BufferMut>::write(self, offset, val)
}

#[inline]
fn write_slice(&mut self, offset: usize, val: &[u8]) {
<[u8] as BufferMut>::write_slice(self, offset, val)
}
}

impl BufferMut for Vec<u8> {
Expand All @@ -254,6 +305,11 @@ impl BufferMut for Vec<u8> {
<[u8] as BufferMut>::write(self, offset, val)
}

#[inline]
fn write_slice(&mut self, offset: usize, val: &[u8]) {
<[u8] as BufferMut>::write_slice(self, offset, val)
}

#[inline]
fn try_enlarge(&mut self, wanted: usize) -> core::result::Result<(), EnlargeError> {
use crate::utils::ByteVecExt;
Expand All @@ -273,6 +329,11 @@ macro_rules! impl_buffer_ref_for_wrappers {
fn read<const N: usize>(&self, offset: usize) -> &[u8; N] {
T::read(self, offset)
}

#[inline]
fn read_slice(&self, offset: usize, val: &mut [u8]) {
T::read_slice(self, offset, val)
}
}
)*};
}
Expand All @@ -292,6 +353,11 @@ macro_rules! impl_buffer_mut_for_wrappers {
T::write(self, offset, val)
}

#[inline]
fn write_slice(&mut self, offset: usize, val: &[u8]) {
T::write_slice(self, offset, val)
}

#[inline]
fn try_enlarge(&mut self, wanted: usize) -> core::result::Result<(), EnlargeError> {
T::try_enlarge(self, wanted)
Expand Down
21 changes: 21 additions & 0 deletions src/core/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct Metadata<E> {
pub alignment: AlignmentValue,
pub has_uniform_min_alignment: bool,
pub min_size: SizeValue,
pub is_pod: bool,
pub extra: E,
}

Expand All @@ -17,6 +18,7 @@ impl Metadata<()> {
alignment: AlignmentValue::new(alignment),
has_uniform_min_alignment: false,
min_size: SizeValue::new(size),
is_pod: false,
extra: (),
}
}
Expand Down Expand Up @@ -49,6 +51,25 @@ impl<E> Metadata<E> {
core::mem::forget(self);
value
}

#[inline]
pub const fn is_pod(self) -> bool {
let value = self.is_pod;
core::mem::forget(self);
value
}

#[inline]
pub const fn pod(mut self) -> Self {
self.is_pod = true;
self
}

#[inline]
pub const fn no_pod(mut self) -> Self {
self.is_pod = false;
self
}
}

/// Base trait for all [WGSL host-shareable types](https://gpuweb.github.io/gpuweb/wgsl/#host-shareable-types)
Expand Down
50 changes: 38 additions & 12 deletions src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ impl<T: ShaderType + ShaderSize, const N: usize> ShaderType for [T; N] {
alignment,
has_uniform_min_alignment: true,
min_size: size,
is_pod: T::METADATA.is_pod() && el_padding == 0,
extra: ArrayMetadata { stride, el_padding },
}
};
Expand Down Expand Up @@ -65,10 +66,17 @@ where
{
#[inline]
fn write_into<B: BufferMut>(&self, writer: &mut Writer<B>) {
for item in self {
WriteInto::write_into(item, writer);
writer.advance(Self::METADATA.el_padding() as usize);
}
if_pod_and_little_endian!(if pod_and_little_endian {
let ptr = self.as_ptr() as *const u8;
let byte_slice: &[u8] =
unsafe { core::slice::from_raw_parts(ptr, core::mem::size_of::<Self>()) };
writer.write_slice(byte_slice);
} else {
for elem in self {
WriteInto::write_into(elem, writer);
writer.advance(Self::METADATA.el_padding() as usize);
}
});
}
}

Expand All @@ -78,10 +86,17 @@ where
{
#[inline]
fn read_from<B: BufferRef>(&mut self, reader: &mut Reader<B>) {
for elem in self {
ReadFrom::read_from(elem, reader);
reader.advance(Self::METADATA.el_padding() as usize);
}
if_pod_and_little_endian!(if pod_and_little_endian {
let ptr = self.as_mut_ptr() as *mut u8;
let byte_slice: &mut [u8] =
unsafe { core::slice::from_raw_parts_mut(ptr, core::mem::size_of::<Self>()) };
reader.read_slice(byte_slice);
} else {
for elem in self {
ReadFrom::read_from(elem, reader);
reader.advance(Self::METADATA.el_padding() as usize);
}
});
}
}

Expand All @@ -91,10 +106,21 @@ where
{
#[inline]
fn create_from<B: BufferRef>(reader: &mut Reader<B>) -> Self {
core::array::from_fn(|_| {
let res = CreateFrom::create_from(reader);
reader.advance(Self::METADATA.el_padding() as usize);
res
if_pod_and_little_endian!(if pod_and_little_endian {
let mut me = core::mem::MaybeUninit::zeroed();
let ptr: *mut core::mem::MaybeUninit<Self> = &mut me;
let ptr = ptr.cast::<u8>();
let byte_slice: &mut [u8] =
unsafe { core::slice::from_raw_parts_mut(ptr, core::mem::size_of::<Self>()) };
reader.read_slice(byte_slice);
// SAFETY: All values were properly initialized by reading the bytes.
unsafe { me.assume_init() }
} else {
core::array::from_fn(|_| {
let res = CreateFrom::create_from(reader);
reader.advance(Self::METADATA.el_padding() as usize);
res
})
})
}
}
44 changes: 26 additions & 18 deletions src/types/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::core::Metadata;

pub trait MatrixScalar {}
pub trait MatrixScalar: crate::ShaderSize {}
impl_marker_trait_for_f32!(MatrixScalar);

pub struct MatrixMetadata {
Expand Down Expand Up @@ -142,6 +142,7 @@ macro_rules! impl_matrix_inner {
alignment,
has_uniform_min_alignment: false,
min_size: size,
is_pod: <[$el_ty; $r] as $crate::private::ShaderType>::METADATA.is_pod() && col_padding == 0,
extra: $crate::private::MatrixMetadata {
col_padding,
},
Expand All @@ -162,12 +163,15 @@ macro_rules! impl_matrix_inner {
#[inline]
fn write_into<B: $crate::private::BufferMut>(&self, writer: &mut $crate::private::Writer<B>) {
let columns = $crate::private::AsRefMatrixParts::<$el_ty, $c, $r>::as_ref_parts(self);
for col in columns {
for el in col {
$crate::private::WriteInto::write_into(el, writer);

$crate::if_pod_and_little_endian!(if pod_and_little_endian {
$crate::private::WriteInto::write_into(columns, writer);
} else {
for col in columns {
$crate::private::WriteInto::write_into(col, writer);
writer.advance(<Self as $crate::private::ShaderType>::METADATA.col_padding() as ::core::primitive::usize);
}
writer.advance(<Self as $crate::private::ShaderType>::METADATA.col_padding() as ::core::primitive::usize);
}
});
}
}

Expand All @@ -179,12 +183,15 @@ macro_rules! impl_matrix_inner {
#[inline]
fn read_from<B: $crate::private::BufferRef>(&mut self, reader: &mut $crate::private::Reader<B>) {
let columns = $crate::private::AsMutMatrixParts::<$el_ty, $c, $r>::as_mut_parts(self);
for col in columns {
for el in col {
$crate::private::ReadFrom::read_from(el, reader);

$crate::if_pod_and_little_endian!(if pod_and_little_endian {
$crate::private::ReadFrom::read_from(columns, reader);
} else {
for col in columns {
$crate::private::ReadFrom::read_from(col, reader);
reader.advance(<Self as $crate::private::ShaderType>::METADATA.col_padding() as ::core::primitive::usize);
}
reader.advance(<Self as $crate::private::ShaderType>::METADATA.col_padding() as ::core::primitive::usize);
}
});
}
}

Expand All @@ -195,14 +202,15 @@ macro_rules! impl_matrix_inner {
{
#[inline]
fn create_from<B: $crate::private::BufferRef>(reader: &mut $crate::private::Reader<B>) -> Self {
let columns = ::core::array::from_fn(|_| {
let col = ::core::array::from_fn(|_| {
$crate::private::CreateFrom::create_from(reader)
});
reader.advance(<Self as $crate::private::ShaderType>::METADATA.col_padding() as ::core::primitive::usize);
col
let columns = $crate::if_pod_and_little_endian!(if pod_and_little_endian {
$crate::private::CreateFrom::create_from(reader)
} else {
::core::array::from_fn(|_| {
let col = $crate::private::CreateFrom::create_from(reader);
reader.advance(<Self as $crate::private::ShaderType>::METADATA.col_padding() as ::core::primitive::usize);
col
})
});

$crate::private::FromMatrixParts::<$el_ty, $c, $r>::from_parts(columns)
}
}
Expand Down
Loading