Skip to content

Commit

Permalink
fix: Make Array types optional
Browse files Browse the repository at this point in the history
* FooArray creation conditioned behind #[soa_array]

* Addressed new Clippy warnings

* Bumped crate versions to 0.6.0

* Depend on updated soa-rs-derive version

* Documented the #[soa_array] attribute

* Patch-bumped the crate version for docs
  • Loading branch information
tim-harding authored May 9, 2024
1 parent 06e14aa commit bf113be
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 142 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "soa-rs"
version = "0.5.1"
version = "0.6.1"
edition = "2021"
license = "MIT"
description = "A Vec-like structure-of-arrays container"
Expand All @@ -16,7 +16,7 @@ debug = true
members = ["soa-rs-derive", "soa-rs-testing"]

[dependencies.soa-rs-derive]
version = "0.4.0"
version = "0.6.0"
path = "soa-rs-derive"

[dependencies.serde]
Expand Down
2 changes: 1 addition & 1 deletion soa-rs-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "soa-rs-derive"
version = "0.4.0"
version = "0.6.0"
edition = "2021"
license = "MIT"
description = "Proc macro derive for soa-rs"
Expand Down
179 changes: 92 additions & 87 deletions soa-rs-derive/src/fields.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
zst::{zst_struct, ZstKind},
SoaDerive,
SoaAttrs, SoaDerive,
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
Expand All @@ -11,15 +11,19 @@ pub fn fields_struct(
vis: Visibility,
fields: Punctuated<Field, Comma>,
kind: FieldKind,
soa_derive: SoaDerive,
soa_attrs: SoaAttrs,
) -> Result<TokenStream, syn::Error> {
let SoaDerive {
r#ref: derive_ref,
ref_mut: derive_ref_mut,
slices: derive_slices,
slices_mut: derive_slices_mut,
array: derive_array,
} = soa_derive;
let SoaAttrs {
derive:
SoaDerive {
r#ref: derive_ref,
ref_mut: derive_ref_mut,
slices: derive_slices,
slices_mut: derive_slices_mut,
array: derive_array,
},
include_array,
} = soa_attrs;

let fields_len = fields.len();
let (vis_all, (ty_all, (ident_all, attrs_all))): (Vec<_>, (Vec<_>, (Vec<_>, Vec<_>))) = fields
Expand Down Expand Up @@ -204,91 +208,93 @@ pub fn fields_struct(
#vis struct #slices_mut<'a> #slices_mut_def
});

let array_def = define(&|ty| quote! { [#ty; N] });
let uninit_def = define(&|ty| quote! { [::std::mem::MaybeUninit<#ty>; K] });
out.append_all(quote! {
#derive_array
#[automatically_derived]
#vis struct #array<const N: usize> #array_def

#[automatically_derived]
impl<const N: usize> #array<N> {
#vis const fn from_array(array: [#ident; N]) -> Self {
let array = ::std::mem::ManuallyDrop::new(array);
let array = ::std::ptr::from_ref::<::std::mem::ManuallyDrop<[#ident; N]>>(&array);
let array = array.cast::<[#ident; N]>();
let array = unsafe { &*array };

struct Uninit<const K: usize> #uninit_def;

let mut uninit: Uninit<N> = Uninit {
#(
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#initializing-an-array-element-by-element
//
// TODO: Prefer when stablized:
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.uninit_array
#ident_all: unsafe { ::std::mem::MaybeUninit::uninit().assume_init() },
)*
};

let mut i = 0;
while i < N {
#(
let src = ::std::ptr::from_ref(&array[i].#ident_all);
unsafe {
uninit.#ident_all[i] = ::std::mem::MaybeUninit::new(src.read());
if include_array {
let array_def = define(&|ty| quote! { [#ty; N] });
let uninit_def = define(&|ty| quote! { [::std::mem::MaybeUninit<#ty>; K] });
out.append_all(quote! {
#derive_array
#[automatically_derived]
#vis struct #array<const N: usize> #array_def

#[automatically_derived]
impl<const N: usize> #array<N> {
#vis const fn from_array(array: [#ident; N]) -> Self {
let array = ::std::mem::ManuallyDrop::new(array);
let array = ::std::ptr::from_ref::<::std::mem::ManuallyDrop<[#ident; N]>>(&array);
let array = array.cast::<[#ident; N]>();
let array = unsafe { &*array };

struct Uninit<const K: usize> #uninit_def;

let mut uninit: Uninit<N> = Uninit {
#(
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#initializing-an-array-element-by-element
//
// TODO: Prefer when stablized:
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.uninit_array
#ident_all: unsafe { ::std::mem::MaybeUninit::uninit().assume_init() },
)*
};

let mut i = 0;
while i < N {
#(
let src = ::std::ptr::from_ref(&array[i].#ident_all);
unsafe {
uninit.#ident_all[i] = ::std::mem::MaybeUninit::new(src.read());
}
)*

i += 1;
}
)*

i += 1;
}

Self {
#(
// TODO: Prefer when stabilized:
// https://doc.rust-lang.org/std/primitive.array.html#method.transpose
#ident_all: unsafe {
::std::mem::transmute_copy(&::std::mem::ManuallyDrop::new(uninit.#ident_all))
},
)*
Self {
#(
// TODO: Prefer when stabilized:
// https://doc.rust-lang.org/std/primitive.array.html#method.transpose
#ident_all: unsafe {
::std::mem::transmute_copy(&::std::mem::ManuallyDrop::new(uninit.#ident_all))
},
)*
}
}
}
}

#[automatically_derived]
impl<const N: usize> ::soa_rs::AsSlice for #array<N> {
type Item = #ident;

fn as_slice(&self) -> ::soa_rs::SliceRef<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_slice().as_ptr().cast_mut();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceRef::from_slice(slice, N) }
#[automatically_derived]
impl<const N: usize> ::soa_rs::AsSlice for #array<N> {
type Item = #ident;

fn as_slice(&self) -> ::soa_rs::SliceRef<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_slice().as_ptr().cast_mut();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceRef::from_slice(slice, N) }
}
}
}

#[automatically_derived]
impl<const N: usize> ::soa_rs::AsMutSlice for #array<N> {
fn as_mut_slice(&mut self) -> ::soa_rs::SliceMut<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_mut_slice().as_mut_ptr();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceMut::from_slice(slice, N) }
#[automatically_derived]
impl<const N: usize> ::soa_rs::AsMutSlice for #array<N> {
fn as_mut_slice(&mut self) -> ::soa_rs::SliceMut<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_mut_slice().as_mut_ptr();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceMut::from_slice(slice, N) }
}
}
}
});
});
}

let indices = std::iter::repeat(()).enumerate().map(|(i, ())| i);
let offsets_len = fields_len - 1;
Expand Down Expand Up @@ -345,7 +351,6 @@ pub fn fields_struct(
type Deref = #deref;
type Ref<'a> = #item_ref<'a> where Self: 'a;
type RefMut<'a> = #item_ref_mut<'a> where Self: 'a;
type Array<const N: usize> = #array<N>;
type Slices<'a> = #slices<'a> where Self: 'a;
type SlicesMut<'a> = #slices_mut<'a> where Self: 'a;
}
Expand Down
106 changes: 65 additions & 41 deletions soa-rs-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::{
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields};
use zst::{zst_struct, ZstKind};

#[proc_macro_derive(Soars, attributes(align, soa_derive))]
#[proc_macro_derive(Soars, attributes(align, soa_derive, soa_array))]
pub fn soa(input: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(input);
let span = input.ident.span();
Expand All @@ -39,24 +39,22 @@ fn soa_inner(input: DeriveInput) -> Result<TokenStream2, SoarsError> {
generics: _,
} = input;

let mut soa_derive = SoaDeriveParse::new();
soa_derive.append(attrs)?;
let soa_derive = soa_derive.into_derive();
let attrs = SoaAttrs::new(attrs)?;
match data {
Data::Struct(strukt) => match strukt.fields {
Fields::Named(fields) => Ok(fields_struct(
ident,
vis,
fields.named,
FieldKind::Named,
soa_derive,
attrs,
)?),
Fields::Unnamed(fields) => Ok(fields_struct(
ident,
vis,
fields.unnamed,
FieldKind::Unnamed,
soa_derive,
attrs,
)?),
Fields::Unit => Ok(zst_struct(ident, vis, ZstKind::Unit)),
},
Expand All @@ -76,6 +74,34 @@ impl From<syn::Error> for SoarsError {
}
}

#[derive(Debug, Clone)]
struct SoaAttrs {
pub derive: SoaDerive,
pub include_array: bool,
}

impl SoaAttrs {
pub fn new(attributes: Vec<Attribute>) -> Result<Self, syn::Error> {
let mut derive_parse = SoaDeriveParse::new();
let mut include_array = false;
for attr in attributes {
let path = attr.path();
if path.is_ident("soa_derive") {
derive_parse.append(attr)?;
} else if path.is_ident("soa_array") {
include_array = true;
} else {
return Err(syn::Error::new_spanned(attr, "Unknown SOA attribute"));
}
}

Ok(Self {
derive: derive_parse.into_derive(),
include_array,
})
}
}

#[derive(Debug, Clone, Default)]
struct SoaDeriveParse {
r#ref: Vec<syn::Path>,
Expand Down Expand Up @@ -123,44 +149,42 @@ impl SoaDeriveParse {
}
}

pub fn append(&mut self, value: Vec<Attribute>) -> Result<(), syn::Error> {
for attr in value {
if attr.path().is_ident("soa_derive") {
let mut collected = vec![];
let mut mask = SoaDeriveMask::new();
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("include") {
mask = SoaDeriveMask::splat(false);
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, true).map_err(|_| {
meta.error(format!("unknown include specifier {:?}", meta.path))
})
})?;
} else if meta.path.is_ident("exclude") {
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, false).map_err(|_| {
meta.error(format!("unknown exclude specifier {:?}", meta.path))
})
})?;
} else {
collected.push(meta.path);
}
Ok(())
pub fn append(&mut self, attr: Attribute) -> Result<(), syn::Error> {
let mut collected = vec![];
let mut mask = SoaDeriveMask::new();
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("include") {
mask = SoaDeriveMask::splat(false);
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, true).map_err(|_| {
meta.error(format!("unknown include specifier {:?}", meta.path))
})
})?;

let to_extend = mask
.r#ref
.then_some(&mut self.r#ref)
.into_iter()
.chain(mask.ref_mut.then_some(&mut self.ref_mut).into_iter())
.chain(mask.slice.then_some(&mut self.slices).into_iter())
.chain(mask.slice_mut.then_some(&mut self.slices_mut).into_iter())
.chain(mask.array.then_some(&mut self.array).into_iter());
for set in to_extend {
set.extend(collected.iter().cloned());
}
} else if meta.path.is_ident("exclude") {
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, false).map_err(|_| {
meta.error(format!("unknown exclude specifier {:?}", meta.path))
})
})?;
} else {
collected.push(meta.path);
}
Ok(())
})?;

let to_extend = mask
.r#ref
.then_some(&mut self.r#ref)
.into_iter()
.chain(mask.ref_mut.then_some(&mut self.ref_mut))
.chain(mask.slice.then_some(&mut self.slices))
.chain(mask.slice_mut.then_some(&mut self.slices_mut))
.chain(mask.array.then_some(&mut self.array));

for set in to_extend {
set.extend(collected.iter().cloned());
}

Ok(())
}
}
Expand Down
1 change: 0 additions & 1 deletion soa-rs-derive/src/zst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub fn zst_struct(ident: Ident, vis: Visibility, kind: ZstKind) -> TokenStream {
type RefMut<'a> = #ident;
type Slices<'a> = #ident;
type SlicesMut<'a> = #ident;
type Array<const N: usize> = #array<N>;
}

#[automatically_derived]
Expand Down
Loading

0 comments on commit bf113be

Please sign in to comment.