Skip to content
122 changes: 13 additions & 109 deletions crates/bevy_render/macros/src/specializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ const SPECIALIZE_ALL_IDENT: &str = "all";
const KEY_ATTR_IDENT: &str = "key";
const KEY_DEFAULT_IDENT: &str = "default";

const BASE_DESCRIPTOR_ATTR_IDENT: &str = "base_descriptor";

enum SpecializeImplTargets {
All,
Specific(Vec<Path>),
Expand Down Expand Up @@ -87,7 +85,6 @@ struct FieldInfo {
ty: Type,
member: Member,
key: Key,
use_base_descriptor: bool,
}

impl FieldInfo {
Expand Down Expand Up @@ -117,15 +114,6 @@ impl FieldInfo {
parse_quote!(#ty: #specialize_path::Specializer<#target_path>)
}
}

fn get_base_descriptor_predicate(
&self,
specialize_path: &Path,
target_path: &Path,
) -> WherePredicate {
let ty = &self.ty;
parse_quote!(#ty: #specialize_path::GetBaseDescriptor<#target_path>)
}
}

fn get_field_info(
Expand All @@ -151,12 +139,8 @@ fn get_field_info(

let mut use_key_field = true;
let mut key = Key::Index(key_index);
let mut use_base_descriptor = false;
for attr in &field.attrs {
match &attr.meta {
Meta::Path(path) if path.is_ident(&BASE_DESCRIPTOR_ATTR_IDENT) => {
use_base_descriptor = true;
}
Meta::List(MetaList { path, tokens, .. }) if path.is_ident(&KEY_ATTR_IDENT) => {
let owned_tokens = tokens.clone().into();
let Ok(parsed_key) = syn::parse::<Key>(owned_tokens) else {
Expand Down Expand Up @@ -190,7 +174,6 @@ fn get_field_info(
ty: field_ty,
member: field_member,
key,
use_base_descriptor,
});
}

Expand Down Expand Up @@ -261,41 +244,18 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
})
.collect();

let base_descriptor_fields = field_info
.iter()
.filter(|field| field.use_base_descriptor)
.collect::<Vec<_>>();

if base_descriptor_fields.len() > 1 {
return syn::Error::new(
Span::call_site(),
"Too many #[base_descriptor] attributes found. It must be present on exactly one field",
)
.into_compile_error()
.into();
}

let base_descriptor_field = base_descriptor_fields.first().copied();

match targets {
SpecializeImplTargets::All => {
let specialize_impl = impl_specialize_all(
&specialize_path,
&ecs_path,
&ast,
&field_info,
&key_patterns,
&key_tuple_idents,
);
let get_base_descriptor_impl = base_descriptor_field
.map(|field_info| impl_get_base_descriptor_all(&specialize_path, &ast, field_info))
.unwrap_or_default();
[specialize_impl, get_base_descriptor_impl]
.into_iter()
.collect()
}
SpecializeImplTargets::Specific(targets) => {
let specialize_impls = targets.iter().map(|target| {
SpecializeImplTargets::All => impl_specialize_all(
&specialize_path,
&ecs_path,
&ast,
&field_info,
&key_patterns,
&key_tuple_idents,
),
SpecializeImplTargets::Specific(targets) => targets
.iter()
.map(|target| {
impl_specialize_specific(
&specialize_path,
&ecs_path,
Expand All @@ -305,14 +265,8 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
&key_patterns,
&key_tuple_idents,
)
});
let get_base_descriptor_impls = targets.iter().filter_map(|target| {
base_descriptor_field.map(|field_info| {
impl_get_base_descriptor_specific(&specialize_path, &ast, field_info, target)
})
});
specialize_impls.chain(get_base_descriptor_impls).collect()
}
})
.collect(),
}
}

Expand Down Expand Up @@ -406,56 +360,6 @@ fn impl_specialize_specific(
})
}

fn impl_get_base_descriptor_specific(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
target_path: &Path,
) -> TokenStream {
let struct_name = &ast.ident;
let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote!(
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
))
}

fn impl_get_base_descriptor_all(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
) -> TokenStream {
let target_path = Path::from(format_ident!("T"));
let struct_name = &ast.ident;
let mut generics = ast.generics.clone();
generics.params.insert(
0,
parse_quote!(#target_path: #specialize_path::Specializable),
);

let where_clause = generics.make_where_clause();
where_clause.predicates.push(
base_descriptor_field_info.get_base_descriptor_predicate(specialize_path, &target_path),
);

let (_, type_generics, _) = ast.generics.split_for_impl();
let (impl_generics, _, where_clause) = &generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote! {
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
})
}

pub fn impl_specializer_key(input: TokenStream) -> TokenStream {
let bevy_render_path: Path = crate::bevy_render_path();
let specialize_path = {
Expand Down
3 changes: 3 additions & 0 deletions crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ use render_asset::{
extract_render_asset_bytes_per_frame, reset_render_asset_bytes_per_frame,
RenderAssetBytesPerFrame, RenderAssetBytesPerFrameLimiter,
};
use render_resource::init_empty_bind_group_layout;
use renderer::{RenderAdapter, RenderDevice, RenderQueue};
use settings::RenderResources;
use sync_world::{
Expand Down Expand Up @@ -467,6 +468,8 @@ impl Plugin for RenderPlugin {
Render,
reset_render_asset_bytes_per_frame.in_set(RenderSystems::Cleanup),
);

render_app.add_systems(RenderStartup, init_empty_bind_group_layout);
}

app.register_type::<alpha::AlphaMode>()
Expand Down
21 changes: 19 additions & 2 deletions crates/bevy_render/src/render_resource/bind_group_layout.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::define_atomic_id;
use crate::WgpuWrapper;
use crate::{define_atomic_id, renderer::RenderDevice, WgpuWrapper};
use bevy_ecs::system::Res;
use bevy_platform::sync::OnceLock;
use core::ops::Deref;

define_atomic_id!(BindGroupLayoutId);
Expand Down Expand Up @@ -62,3 +63,19 @@ impl Deref for BindGroupLayout {
&self.value
}
}

static EMPTY_BIND_GROUP_LAYOUT: OnceLock<BindGroupLayout> = OnceLock::new();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting this here (not blocking) gfx-rs/wgpu#7862


pub(crate) fn init_empty_bind_group_layout(render_device: Res<RenderDevice>) {
let layout = render_device.create_bind_group_layout(Some("empty_bind_group_layout"), &[]);
EMPTY_BIND_GROUP_LAYOUT
.set(layout)
.expect("init_empty_bind_group_layout was called more than once");
}

pub fn empty_bind_group_layout() -> BindGroupLayout {
EMPTY_BIND_GROUP_LAYOUT
.get()
.expect("init_empty_bind_group_layout was not called")
.clone()
}
32 changes: 31 additions & 1 deletion crates/bevy_render/src/render_resource/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::ShaderDefVal;
use super::{empty_bind_group_layout, ShaderDefVal};
use crate::mesh::VertexBufferLayout;
use crate::WgpuWrapper;
use crate::{
Expand All @@ -7,7 +7,9 @@ use crate::{
};
use alloc::borrow::Cow;
use bevy_asset::Handle;
use core::iter;
use core::ops::Deref;
use thiserror::Error;
use wgpu::{
ColorTargetState, DepthStencilState, MultisampleState, PrimitiveState, PushConstantRange,
};
Expand Down Expand Up @@ -112,6 +114,20 @@ pub struct RenderPipelineDescriptor {
pub zero_initialize_workgroup_memory: bool,
}

#[derive(Copy, Clone, Debug, Error)]
#[error("RenderPipelineDescriptor has no FragmentState configured")]
pub struct NoFragmentStateError;

impl RenderPipelineDescriptor {
pub fn fragment_mut(&mut self) -> Result<&mut FragmentState, NoFragmentStateError> {
self.fragment.as_mut().ok_or(NoFragmentStateError)
}

pub fn set_layout(&mut self, index: usize, layout: BindGroupLayout) {
filling_set_at(&mut self.layout, index, empty_bind_group_layout(), layout);
}
}

#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct VertexState {
/// The compiled shader module for this stage.
Expand All @@ -137,6 +153,12 @@ pub struct FragmentState {
pub targets: Vec<Option<ColorTargetState>>,
}

impl FragmentState {
pub fn set_target(&mut self, index: usize, target: ColorTargetState) {
filling_set_at(&mut self.targets, index, None, Some(target));
}
}

/// Describes a compute pipeline.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ComputePipelineDescriptor {
Expand All @@ -153,3 +175,11 @@ pub struct ComputePipelineDescriptor {
/// If this is false, reading from workgroup variables before writing to them will result in garbage values.
pub zero_initialize_workgroup_memory: bool,
}

// utility function to set a value at the specified index, extending with
// a filler value if the index is out of bounds.
fn filling_set_at<T: Clone>(vec: &mut Vec<T>, index: usize, filler: T, value: T) {
let num_to_fill = (index + 1).saturating_sub(vec.len());
vec.extend(iter::repeat_n(filler, num_to_fill));
vec[index] = value;
}
Loading
Loading