Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional array #14

Merged
merged 6 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "soa-rs"
version = "0.5.1"
version = "0.6.1"
edition = "2021"
license = "MIT"
description = "A Vec-like structure-of-arrays container"
Expand All @@ -16,7 +16,7 @@ debug = true
members = ["soa-rs-derive", "soa-rs-testing"]

[dependencies.soa-rs-derive]
version = "0.4.0"
version = "0.6.0"
path = "soa-rs-derive"

[dependencies.serde]
Expand Down
2 changes: 1 addition & 1 deletion soa-rs-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "soa-rs-derive"
version = "0.4.0"
version = "0.6.0"
edition = "2021"
license = "MIT"
description = "Proc macro derive for soa-rs"
Expand Down
179 changes: 92 additions & 87 deletions soa-rs-derive/src/fields.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
zst::{zst_struct, ZstKind},
SoaDerive,
SoaAttrs, SoaDerive,
};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
Expand All @@ -11,15 +11,19 @@ pub fn fields_struct(
vis: Visibility,
fields: Punctuated<Field, Comma>,
kind: FieldKind,
soa_derive: SoaDerive,
soa_attrs: SoaAttrs,
) -> Result<TokenStream, syn::Error> {
let SoaDerive {
r#ref: derive_ref,
ref_mut: derive_ref_mut,
slices: derive_slices,
slices_mut: derive_slices_mut,
array: derive_array,
} = soa_derive;
let SoaAttrs {
derive:
SoaDerive {
r#ref: derive_ref,
ref_mut: derive_ref_mut,
slices: derive_slices,
slices_mut: derive_slices_mut,
array: derive_array,
},
include_array,
} = soa_attrs;

let fields_len = fields.len();
let (vis_all, (ty_all, (ident_all, attrs_all))): (Vec<_>, (Vec<_>, (Vec<_>, Vec<_>))) = fields
Expand Down Expand Up @@ -204,91 +208,93 @@ pub fn fields_struct(
#vis struct #slices_mut<'a> #slices_mut_def
});

let array_def = define(&|ty| quote! { [#ty; N] });
let uninit_def = define(&|ty| quote! { [::std::mem::MaybeUninit<#ty>; K] });
out.append_all(quote! {
#derive_array
#[automatically_derived]
#vis struct #array<const N: usize> #array_def

#[automatically_derived]
impl<const N: usize> #array<N> {
#vis const fn from_array(array: [#ident; N]) -> Self {
let array = ::std::mem::ManuallyDrop::new(array);
let array = ::std::ptr::from_ref::<::std::mem::ManuallyDrop<[#ident; N]>>(&array);
let array = array.cast::<[#ident; N]>();
let array = unsafe { &*array };

struct Uninit<const K: usize> #uninit_def;

let mut uninit: Uninit<N> = Uninit {
#(
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#initializing-an-array-element-by-element
//
// TODO: Prefer when stablized:
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.uninit_array
#ident_all: unsafe { ::std::mem::MaybeUninit::uninit().assume_init() },
)*
};

let mut i = 0;
while i < N {
#(
let src = ::std::ptr::from_ref(&array[i].#ident_all);
unsafe {
uninit.#ident_all[i] = ::std::mem::MaybeUninit::new(src.read());
if include_array {
let array_def = define(&|ty| quote! { [#ty; N] });
let uninit_def = define(&|ty| quote! { [::std::mem::MaybeUninit<#ty>; K] });
out.append_all(quote! {
#derive_array
#[automatically_derived]
#vis struct #array<const N: usize> #array_def

#[automatically_derived]
impl<const N: usize> #array<N> {
#vis const fn from_array(array: [#ident; N]) -> Self {
let array = ::std::mem::ManuallyDrop::new(array);
let array = ::std::ptr::from_ref::<::std::mem::ManuallyDrop<[#ident; N]>>(&array);
let array = array.cast::<[#ident; N]>();
let array = unsafe { &*array };

struct Uninit<const K: usize> #uninit_def;

let mut uninit: Uninit<N> = Uninit {
#(
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#initializing-an-array-element-by-element
//
// TODO: Prefer when stablized:
// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.uninit_array
#ident_all: unsafe { ::std::mem::MaybeUninit::uninit().assume_init() },
)*
};

let mut i = 0;
while i < N {
#(
let src = ::std::ptr::from_ref(&array[i].#ident_all);
unsafe {
uninit.#ident_all[i] = ::std::mem::MaybeUninit::new(src.read());
}
)*

i += 1;
}
)*

i += 1;
}

Self {
#(
// TODO: Prefer when stabilized:
// https://doc.rust-lang.org/std/primitive.array.html#method.transpose
#ident_all: unsafe {
::std::mem::transmute_copy(&::std::mem::ManuallyDrop::new(uninit.#ident_all))
},
)*
Self {
#(
// TODO: Prefer when stabilized:
// https://doc.rust-lang.org/std/primitive.array.html#method.transpose
#ident_all: unsafe {
::std::mem::transmute_copy(&::std::mem::ManuallyDrop::new(uninit.#ident_all))
},
)*
}
}
}
}

#[automatically_derived]
impl<const N: usize> ::soa_rs::AsSlice for #array<N> {
type Item = #ident;

fn as_slice(&self) -> ::soa_rs::SliceRef<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_slice().as_ptr().cast_mut();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceRef::from_slice(slice, N) }
#[automatically_derived]
impl<const N: usize> ::soa_rs::AsSlice for #array<N> {
type Item = #ident;

fn as_slice(&self) -> ::soa_rs::SliceRef<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_slice().as_ptr().cast_mut();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceRef::from_slice(slice, N) }
}
}
}

#[automatically_derived]
impl<const N: usize> ::soa_rs::AsMutSlice for #array<N> {
fn as_mut_slice(&mut self) -> ::soa_rs::SliceMut<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_mut_slice().as_mut_ptr();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceMut::from_slice(slice, N) }
#[automatically_derived]
impl<const N: usize> ::soa_rs::AsMutSlice for #array<N> {
fn as_mut_slice(&mut self) -> ::soa_rs::SliceMut<'_, Self::Item> {
let raw = #raw {
#(
#ident_all: {
let ptr = self.#ident_all.as_mut_slice().as_mut_ptr();
unsafe { ::std::ptr::NonNull::new_unchecked(ptr) }
},
)*
};
let slice = ::soa_rs::Slice::with_raw(raw);
unsafe { ::soa_rs::SliceMut::from_slice(slice, N) }
}
}
}
});
});
}

let indices = std::iter::repeat(()).enumerate().map(|(i, ())| i);
let offsets_len = fields_len - 1;
Expand Down Expand Up @@ -345,7 +351,6 @@ pub fn fields_struct(
type Deref = #deref;
type Ref<'a> = #item_ref<'a> where Self: 'a;
type RefMut<'a> = #item_ref_mut<'a> where Self: 'a;
type Array<const N: usize> = #array<N>;
type Slices<'a> = #slices<'a> where Self: 'a;
type SlicesMut<'a> = #slices_mut<'a> where Self: 'a;
}
Expand Down
106 changes: 65 additions & 41 deletions soa-rs-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::{
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields};
use zst::{zst_struct, ZstKind};

#[proc_macro_derive(Soars, attributes(align, soa_derive))]
#[proc_macro_derive(Soars, attributes(align, soa_derive, soa_array))]
pub fn soa(input: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(input);
let span = input.ident.span();
Expand All @@ -39,24 +39,22 @@ fn soa_inner(input: DeriveInput) -> Result<TokenStream2, SoarsError> {
generics: _,
} = input;

let mut soa_derive = SoaDeriveParse::new();
soa_derive.append(attrs)?;
let soa_derive = soa_derive.into_derive();
let attrs = SoaAttrs::new(attrs)?;
match data {
Data::Struct(strukt) => match strukt.fields {
Fields::Named(fields) => Ok(fields_struct(
ident,
vis,
fields.named,
FieldKind::Named,
soa_derive,
attrs,
)?),
Fields::Unnamed(fields) => Ok(fields_struct(
ident,
vis,
fields.unnamed,
FieldKind::Unnamed,
soa_derive,
attrs,
)?),
Fields::Unit => Ok(zst_struct(ident, vis, ZstKind::Unit)),
},
Expand All @@ -76,6 +74,34 @@ impl From<syn::Error> for SoarsError {
}
}

#[derive(Debug, Clone)]
struct SoaAttrs {
pub derive: SoaDerive,
pub include_array: bool,
}

impl SoaAttrs {
pub fn new(attributes: Vec<Attribute>) -> Result<Self, syn::Error> {
let mut derive_parse = SoaDeriveParse::new();
let mut include_array = false;
for attr in attributes {
let path = attr.path();
if path.is_ident("soa_derive") {
derive_parse.append(attr)?;
} else if path.is_ident("soa_array") {
include_array = true;
} else {
return Err(syn::Error::new_spanned(attr, "Unknown SOA attribute"));
}
}

Ok(Self {
derive: derive_parse.into_derive(),
include_array,
})
}
}

#[derive(Debug, Clone, Default)]
struct SoaDeriveParse {
r#ref: Vec<syn::Path>,
Expand Down Expand Up @@ -123,44 +149,42 @@ impl SoaDeriveParse {
}
}

pub fn append(&mut self, value: Vec<Attribute>) -> Result<(), syn::Error> {
for attr in value {
if attr.path().is_ident("soa_derive") {
let mut collected = vec![];
let mut mask = SoaDeriveMask::new();
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("include") {
mask = SoaDeriveMask::splat(false);
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, true).map_err(|_| {
meta.error(format!("unknown include specifier {:?}", meta.path))
})
})?;
} else if meta.path.is_ident("exclude") {
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, false).map_err(|_| {
meta.error(format!("unknown exclude specifier {:?}", meta.path))
})
})?;
} else {
collected.push(meta.path);
}
Ok(())
pub fn append(&mut self, attr: Attribute) -> Result<(), syn::Error> {
let mut collected = vec![];
let mut mask = SoaDeriveMask::new();
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("include") {
mask = SoaDeriveMask::splat(false);
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, true).map_err(|_| {
meta.error(format!("unknown include specifier {:?}", meta.path))
})
})?;

let to_extend = mask
.r#ref
.then_some(&mut self.r#ref)
.into_iter()
.chain(mask.ref_mut.then_some(&mut self.ref_mut).into_iter())
.chain(mask.slice.then_some(&mut self.slices).into_iter())
.chain(mask.slice_mut.then_some(&mut self.slices_mut).into_iter())
.chain(mask.array.then_some(&mut self.array).into_iter());
for set in to_extend {
set.extend(collected.iter().cloned());
}
} else if meta.path.is_ident("exclude") {
meta.parse_nested_meta(|meta| {
mask.set_by_path(&meta.path, false).map_err(|_| {
meta.error(format!("unknown exclude specifier {:?}", meta.path))
})
})?;
} else {
collected.push(meta.path);
}
Ok(())
})?;

let to_extend = mask
.r#ref
.then_some(&mut self.r#ref)
.into_iter()
.chain(mask.ref_mut.then_some(&mut self.ref_mut))
.chain(mask.slice.then_some(&mut self.slices))
.chain(mask.slice_mut.then_some(&mut self.slices_mut))
.chain(mask.array.then_some(&mut self.array));

for set in to_extend {
set.extend(collected.iter().cloned());
}

Ok(())
}
}
Expand Down
1 change: 0 additions & 1 deletion soa-rs-derive/src/zst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub fn zst_struct(ident: Ident, vis: Visibility, kind: ZstKind) -> TokenStream {
type RefMut<'a> = #ident;
type Slices<'a> = #ident;
type SlicesMut<'a> = #ident;
type Array<const N: usize> = #array<N>;
}

#[automatically_derived]
Expand Down
Loading