Skip to content

Commit

Permalink
Bind group deduplication (#2623)
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald authored Apr 25, 2022
1 parent c226a10 commit bc850d2
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 59 deletions.
55 changes: 38 additions & 17 deletions wgpu-core/src/command/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ invalidations or index format changes.
use crate::{
binding_model::buffer_binding_type_alignment,
command::{
BasePass, DrawError, MapPassErr, PassErrorScope, RenderCommand, RenderCommandError,
StateChange,
BasePass, BindGroupStateChange, DrawError, MapPassErr, PassErrorScope, RenderCommand,
RenderCommandError, StateChange,
},
conv,
device::{
Expand Down Expand Up @@ -86,6 +86,12 @@ pub struct RenderBundleEncoder {
parent_id: id::DeviceId,
pub(crate) context: RenderPassContext,
pub(crate) is_ds_read_only: bool,

// Resource binding dedupe state.
#[cfg_attr(feature = "serial-pass", serde(skip))]
current_bind_groups: BindGroupStateChange,
#[cfg_attr(feature = "serial-pass", serde(skip))]
current_pipeline: StateChange<id::RenderPipelineId>,
}

impl RenderBundleEncoder {
Expand Down Expand Up @@ -126,6 +132,9 @@ impl RenderBundleEncoder {
}
None => false,
},

current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),
})
}

Expand All @@ -143,6 +152,9 @@ impl RenderBundleEncoder {
multiview: None,
},
is_ds_read_only: false,

current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),
}
}

Expand Down Expand Up @@ -180,7 +192,7 @@ impl RenderBundleEncoder {
raw_dynamic_offsets: Vec::new(),
flat_dynamic_offsets: Vec::new(),
used_bind_groups: 0,
pipeline: StateChange::new(),
pipeline: None,
};
let mut commands = Vec::new();
let mut base = self.base.as_ref();
Expand Down Expand Up @@ -252,9 +264,8 @@ impl RenderBundleEncoder {
}
RenderCommand::SetPipeline(pipeline_id) => {
let scope = PassErrorScope::SetPipelineRender(pipeline_id);
if state.pipeline.set_and_check_redundant(pipeline_id) {
continue;
}

state.pipeline = Some(pipeline_id);

let pipeline = state
.trackers
Expand Down Expand Up @@ -370,7 +381,7 @@ impl RenderBundleEncoder {
let scope = PassErrorScope::Draw {
indexed: false,
indirect: false,
pipeline: state.pipeline.last_state,
pipeline: state.pipeline,
};
let vertex_limits = state.vertex_limits();
let last_vertex = first_vertex + vertex_count;
Expand Down Expand Up @@ -405,7 +416,7 @@ impl RenderBundleEncoder {
let scope = PassErrorScope::Draw {
indexed: true,
indirect: false,
pipeline: state.pipeline.last_state,
pipeline: state.pipeline,
};
//TODO: validate that base_vertex + max_index() is within the provided range
let vertex_limits = state.vertex_limits();
Expand Down Expand Up @@ -441,7 +452,7 @@ impl RenderBundleEncoder {
let scope = PassErrorScope::Draw {
indexed: false,
indirect: true,
pipeline: state.pipeline.last_state,
pipeline: state.pipeline,
};
device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
Expand Down Expand Up @@ -474,7 +485,7 @@ impl RenderBundleEncoder {
let scope = PassErrorScope::Draw {
indexed: true,
indirect: true,
pipeline: state.pipeline.last_state,
pipeline: state.pipeline,
};
device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
Expand Down Expand Up @@ -990,7 +1001,7 @@ struct State {
raw_dynamic_offsets: Vec<wgt::DynamicOffset>,
flat_dynamic_offsets: Vec<wgt::DynamicOffset>,
used_bind_groups: usize,
pipeline: StateChange<id::RenderPipelineId>,
pipeline: Option<id::RenderPipelineId>,
}

impl State {
Expand Down Expand Up @@ -1222,24 +1233,34 @@ pub mod bundle_ffi {
offsets: *const DynamicOffset,
offset_length: usize,
) {
let redundant = bundle.current_bind_groups.set_and_check_redundant(
bind_group_id,
index,
&mut bundle.base.dynamic_offsets,
offsets,
offset_length,
);

if redundant {
return;
}

bundle.base.commands.push(RenderCommand::SetBindGroup {
index: index.try_into().unwrap(),
num_dynamic_offsets: offset_length.try_into().unwrap(),
bind_group_id,
});
if offset_length != 0 {
bundle
.base
.dynamic_offsets
.extend_from_slice(slice::from_raw_parts(offsets, offset_length));
}
}

#[no_mangle]
pub extern "C" fn wgpu_render_bundle_set_pipeline(
bundle: &mut RenderBundleEncoder,
pipeline_id: id::RenderPipelineId,
) {
if bundle.current_pipeline.set_and_check_redundant(pipeline_id) {
return;
}

bundle
.base
.commands
Expand Down
48 changes: 33 additions & 15 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::{
bind::Binder,
end_pipeline_statistics_query,
memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
BasePass, BasePassRef, CommandBuffer, CommandEncoderError, CommandEncoderStatus,
MapPassErr, PassErrorScope, QueryUseError, StateChange,
BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
},
device::MissingDownlevelFlags,
error::{ErrorFormatter, PrettyError},
Expand Down Expand Up @@ -76,13 +76,22 @@ pub enum ComputeCommand {
pub struct ComputePass {
base: BasePass<ComputeCommand>,
parent_id: id::CommandEncoderId,

// Resource binding dedupe state.
#[cfg_attr(feature = "serial-pass", serde(skip))]
current_bind_groups: BindGroupStateChange,
#[cfg_attr(feature = "serial-pass", serde(skip))]
current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
Self {
base: BasePass::new(&desc.label),
parent_id,

current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),
}
}

Expand Down Expand Up @@ -222,7 +231,7 @@ where
#[derive(Debug)]
struct State {
binder: Binder,
pipeline: StateChange<id::ComputePipelineId>,
pipeline: Option<id::ComputePipelineId>,
trackers: StatefulTrackerSubset,
debug_scope_depth: u32,
}
Expand All @@ -236,7 +245,7 @@ impl State {
index: bind_mask.trailing_zeros(),
});
}
if self.pipeline.is_unset() {
if self.pipeline.is_none() {
return Err(DispatchError::MissingPipeline);
}
self.binder.check_late_buffer_bindings()?;
Expand Down Expand Up @@ -325,7 +334,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

let mut state = State {
binder: Binder::new(),
pipeline: StateChange::new(),
pipeline: None,
trackers: StatefulTrackerSubset::new(A::VARIANT),
debug_scope_depth: 0,
};
Expand Down Expand Up @@ -420,9 +429,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
ComputeCommand::SetPipeline(pipeline_id) => {
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

if state.pipeline.set_and_check_redundant(pipeline_id) {
continue;
}
state.pipeline = Some(pipeline_id);

let pipeline = cmd_buf
.trackers
Expand Down Expand Up @@ -524,7 +531,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
ComputeCommand::Dispatch(groups) => {
let scope = PassErrorScope::Dispatch {
indirect: false,
pipeline: state.pipeline.last_state,
pipeline: state.pipeline,
};

fixup_discarded_surfaces(
Expand Down Expand Up @@ -568,7 +575,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
ComputeCommand::DispatchIndirect { buffer_id, offset } => {
let scope = PassErrorScope::Dispatch {
indirect: true,
pipeline: state.pipeline.last_state,
pipeline: state.pipeline,
};

state.is_ready().map_pass_err(scope)?;
Expand Down Expand Up @@ -750,23 +757,34 @@ pub mod compute_ffi {
offsets: *const DynamicOffset,
offset_length: usize,
) {
let redundant = pass.current_bind_groups.set_and_check_redundant(
bind_group_id,
index,
&mut pass.base.dynamic_offsets,
offsets,
offset_length,
);

if redundant {
return;
}

pass.base.commands.push(ComputeCommand::SetBindGroup {
index: index.try_into().unwrap(),
num_dynamic_offsets: offset_length.try_into().unwrap(),
bind_group_id,
});
if offset_length != 0 {
pass.base
.dynamic_offsets
.extend_from_slice(slice::from_raw_parts(offsets, offset_length));
}
}

#[no_mangle]
pub extern "C" fn wgpu_compute_pass_set_pipeline(
pass: &mut ComputePass,
pipeline_id: id::ComputePipelineId,
) {
if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
return;
}

pass.base
.commands
.push(ComputeCommand::SetPipeline(pipeline_id));
Expand Down
61 changes: 57 additions & 4 deletions wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ mod query;
mod render;
mod transfer;

use std::slice;

pub(crate) use self::clear::clear_texture_no_device;
pub use self::{
bundle::*, clear::ClearError, compute::*, draw::*, query::*, render::*, transfer::*,
Expand Down Expand Up @@ -405,7 +407,7 @@ where
}
}

#[derive(Debug)]
#[derive(Debug, Copy, Clone)]
struct StateChange<T> {
last_state: Option<T>,
}
Expand All @@ -419,14 +421,65 @@ impl<T: Copy + PartialEq> StateChange<T> {
self.last_state = Some(new_state);
already_set
}
fn is_unset(&self) -> bool {
self.last_state.is_none()
}
fn reset(&mut self) {
self.last_state = None;
}
}

impl<T: Copy + PartialEq> Default for StateChange<T> {
fn default() -> Self {
Self::new()
}
}

#[derive(Debug)]
struct BindGroupStateChange {
last_states: [StateChange<id::BindGroupId>; hal::MAX_BIND_GROUPS],
}

impl BindGroupStateChange {
fn new() -> Self {
Self {
last_states: [StateChange::new(); hal::MAX_BIND_GROUPS],
}
}

unsafe fn set_and_check_redundant(
&mut self,
bind_group_id: id::BindGroupId,
index: u32,
dynamic_offsets: &mut Vec<u32>,
offsets: *const wgt::DynamicOffset,
offset_length: usize,
) -> bool {
// For now never deduplicate bind groups with dynamic offsets.
if offset_length == 0 {
// If this get returns None, that means we're well over the limit, so let the call through to get a proper error
if let Some(current_bind_group) = self.last_states.get_mut(index as usize) {
// Bail out if we're binding the same bind group.
if current_bind_group.set_and_check_redundant(bind_group_id) {
return true;
}
}
} else {
// We intentionally remove the memory of this bind group if we have dynamic offsets,
// such that if you try to bind this bind group later with _no_ dynamic offsets it
// tries to bind it again and gives a proper validation error.
if let Some(current_bind_group) = self.last_states.get_mut(index as usize) {
current_bind_group.reset();
}
dynamic_offsets.extend_from_slice(slice::from_raw_parts(offsets, offset_length));
}
false
}
}

impl Default for BindGroupStateChange {
fn default() -> Self {
Self::new()
}
}

trait MapPassErr<T, O> {
fn map_pass_err(self, scope: PassErrorScope) -> Result<T, O>;
}
Expand Down
Loading

0 comments on commit bc850d2

Please sign in to comment.