ensure safety of indirect dispatch
by injecting a compute shader that validates the content of the indirect buffer
teoxoy committed May 17, 2024
1 parent 4902e47 commit 018b23b
Showing 10 changed files with 508 additions and 2 deletions.
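The validation shader that this commit injects is created in files not shown below, so the following is only a rough sketch of what it plausibly looks like, written in the same style as the test's shader source. The hard-coded limit is a hypothetical stand-in for the device's actual max_compute_workgroups_per_dimension, which the real implementation supplies by other means:

const VALIDATION_SRC_SKETCH: &str = "
@group(0) @binding(0)
var<storage, read> src: array<u32, 3>;

@group(0) @binding(1)
var<storage, read_write> dst: array<u32, 3>;

@compute @workgroup_size(1)
fn main() {
    // Hypothetical stand-in for max_compute_workgroups_per_dimension.
    let max = 65535u;
    let in_range = src[0] <= max && src[1] <= max && src[2] <= max;
    // Zero the arguments entirely when any dimension is out of range, so the
    // subsequent indirect dispatch becomes a no-op instead of undefined behavior.
    dst[0] = select(0u, src[0], in_range);
    dst[1] = select(0u, src[1], in_range);
    dst[2] = select(0u, src[2], in_range);
}
";

The bind group created in compute.rs below matches this layout: a 12-byte slice of the user's indirect buffer at binding 0 and a fresh 12-byte INDIRECT | STORAGE buffer at binding 1.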
tests/tests/dispatch_workgroups_indirect.rs (116 additions, 0 deletions)
@@ -0,0 +1,116 @@
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters};

const SHADER_SRC: &str = "
@group(0) @binding(0)
var<storage, read_write> out: u32;
@compute @workgroup_size(1)
fn main() {
out = 1u;
}
";

#[gpu_test]
static CHECK_NUM_WORKGROUPS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits())
.run_async(|ctx| async move {
let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: "main",
compilation_options: Default::default(),
cache: None,
});

let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 4,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 4,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});

let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::INDIRECT
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::UNIFORM,
mapped_at_creation: false,
});

let max = ctx.adapter.limits().max_compute_workgroups_per_dimension;
ctx.queue
.write_buffer(&indirect_buffer, 0, bytemuck::bytes_of(&[max + 1, 1, 1]));

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: out_buffer.as_entire_binding(),
}],
});

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

{
let mut compute_pass =
encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, 0);
}

encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 4);

ctx.queue.submit(Some(encoder.finish()));

readback_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| {});

ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let view = readback_buffer.slice(..).get_mapped_range();
// Make sure the dispatch was discarded
assert!(view.iter().all(|v| *v == 0));

// Test that unsetting the bind group works properly
{
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
let mut compute_pass =
encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
compute_pass.set_pipeline(&pipeline);
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, 0);
}
let _ = encoder.finish();
let error = pollster::block_on(ctx.device.pop_error_scope());
assert!(error.map_or(false, |error| format!("{error}")
.contains("Expected bind group is missing")));
}
});
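For contrast, a hypothetical variation not part of this commit's test: writing in-range arguments instead would let the dispatch run, and the shader would store 1 in the output buffer.

    // Same setup as above, but with valid workgroup counts:
    ctx.queue
        .write_buffer(&indirect_buffer, 0, bytemuck::bytes_of(&[1u32, 1, 1]));
    // ...encode, submit, and read back exactly as before...
    // The dispatch now executes, so the readback holds 1u32:
    assert_eq!(u32::from_le_bytes(view[..4].try_into().unwrap()), 1);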
tests/tests/root.rs (1 addition, 0 deletions)
@@ -14,6 +14,7 @@ mod clear_texture;
mod compute_pass_resource_ownership;
mod create_surface_error;
mod device;
mod dispatch_workgroups_indirect;
mod encoder;
mod external_texture;
mod float32_filterable;
wgpu-core/src/command/bind.rs (13 additions, 1 deletion)
@@ -131,7 +131,7 @@ mod compat {
diff.push(format!("Expected {expected_bgl_type} bind group layout, got {assigned_bgl_type}"))
}
} else {
diff.push("Assigned bind group layout not found (internal error)".to_owned());
diff.push("Expected bind group is missing".to_owned());
}
} else {
diff.push("Expected bind group layout not found (internal error)".to_owned());
@@ -191,6 +191,10 @@
self.make_range(index)
}

pub fn unassign(&mut self, index: usize) {
self.entries[index].assigned = None;
}

pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
self.entries
.iter()
@@ -358,6 +362,14 @@ impl<A: HalApi> Binder<A> {
&self.payloads[bind_range]
}

pub(super) fn unassign_group(&mut self, index: usize) {
log::trace!("\tBinding [{}] = null", index);

self.payloads[index].reset();

self.manager.unassign(index);
}

pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
let payloads = &self.payloads;
self.manager
wgpu-core/src/command/compute.rs (206 additions, 1 deletion)
@@ -146,6 +146,20 @@ pub enum ComputePassErrorInner {
MissingFeatures(#[from] MissingFeatures),
#[error(transparent)]
MissingDownlevelFlags(#[from] MissingDownlevelFlags),
#[error(transparent)]
IndirectValidation(#[from] ComputePassIndirectValidationError),
}

#[derive(Clone, Debug, Error)]
pub enum ComputePassIndirectValidationError {
#[error(transparent)]
ValidationPipeline(
#[from] crate::pipeline::CreateDispatchWorkgroupsIndirectValidationPipelineError,
),
#[error(transparent)]
Buffer(#[from] crate::resource::CreateBufferError),
#[error(transparent)]
BindGroup(#[from] crate::binding_model::CreateBindGroupError),
}

impl PrettyError for ComputePassErrorInner {
@@ -283,9 +297,26 @@ impl Global {
&self,
pass: &ComputePass<A>,
) -> Result<(), ComputePassError> {
let mut base = pass.base.as_ref();

let new_commands = base
.commands
.iter()
.any(|cmd| matches!(cmd, ArcComputeCommand::DispatchIndirect { .. }))
.then(|| {
self.command_encoder_inject_dispatch_workgroups_indirect_validation(
pass.parent_id,
base.commands,
)
})
.transpose()?;
if let Some(new_commands) = new_commands.as_ref() {
base.commands = new_commands;
}

self.command_encoder_run_compute_pass_impl(
pass.parent_id,
-            pass.base.as_ref(),
+            base,
pass.timestamp_writes.as_ref(),
)
}
@@ -515,6 +546,20 @@ impl Global {
}
}
}
ArcComputeCommand::UnsetBindGroup { index } => {
let scope = PassErrorScope::UnsetBindGroup(*index);

let max_bind_groups = cmd_buf.limits.max_bind_groups;
if *index >= max_bind_groups {
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
index: *index,
max: max_bind_groups,
})
.map_pass_err(scope);
}

state.binder.unassign_group(*index as usize);
}
ArcComputeCommand::SetPipeline(pipeline) => {
let pipeline_id = pipeline.as_info().id();
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
@@ -823,6 +868,166 @@ impl Global {

Ok(())
}

fn command_encoder_inject_dispatch_workgroups_indirect_validation<A: HalApi>(
&self,
encoder_id: id::CommandEncoderId,
commands: &[ArcComputeCommand<A>],
) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
profiling::scope!("CommandEncoder::inject_dispatch_workgroups_indirect_validation");
let scope = PassErrorScope::Pass(encoder_id);

let hub = A::hub(self);

let cmd_buf: Arc<CommandBuffer<A>> =
CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?;
let device = &cmd_buf.device;
if !device.is_valid() {
return Err(ComputePassErrorInner::InvalidDevice(
cmd_buf.device.as_info().id(),
))
.map_pass_err(scope);
}
let device_id = device.as_info().id();

let mut new_commands = Vec::with_capacity(commands.len());
let mut current_pipeline = None;
let mut current_bind_group_0 = None;

for command in commands {
match command {
ArcComputeCommand::SetBindGroup {
index: 0,
num_dynamic_offsets,
bind_group,
} => {
current_bind_group_0 = Some((*num_dynamic_offsets, bind_group.clone()));
new_commands.push(command.clone());
}
ArcComputeCommand::SetPipeline(pipeline) => {
current_pipeline = Some(pipeline.clone());
new_commands.push(command.clone());
}
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
// if there is no pipeline set, don't inject the validation commands as we will error anyway
if let Some(original_pipeline) = current_pipeline.clone() {
let (validation_pipeline_id, validation_bgl_id) = self
.device_get_or_create_dispatch_workgroups_indirect_validation_pipeline::<A>(
device_id,
)
.map_err(ComputePassIndirectValidationError::ValidationPipeline)
.map_pass_err(scope)?;

let (dst_buffer_id, error) = self.device_create_buffer::<A>(
device_id,
&crate::resource::BufferDescriptor {
label: None,
size: 4 * 3,
usage: wgt::BufferUsages::INDIRECT | wgt::BufferUsages::STORAGE,
mapped_at_creation: false,
},
None,
);
if let Some(error) = error {
return Err(ComputePassIndirectValidationError::Buffer(error))
.map_pass_err(scope)?;
}

let (bind_group_id, error) = self.device_create_bind_group::<A>(
device_id,
&crate::binding_model::BindGroupDescriptor {
label: None,
layout: validation_bgl_id,
entries: std::borrow::Cow::Borrowed(&[
crate::binding_model::BindGroupEntry {
binding: 0,
resource: crate::binding_model::BindingResource::Buffer(
crate::binding_model::BufferBinding {
buffer_id: buffer.as_info().id(),
offset: *offset,
size: Some(
std::num::NonZeroU64::new(4 * 3).unwrap(),
),
},
),
},
crate::binding_model::BindGroupEntry {
binding: 1,
resource: crate::binding_model::BindingResource::Buffer(
crate::binding_model::BufferBinding {
buffer_id: dst_buffer_id,
offset: 0,
size: Some(
std::num::NonZeroU64::new(4 * 3).unwrap(),
),
},
),
},
]),
},
None,
);
if let Some(error) = error {
return Err(ComputePassIndirectValidationError::BindGroup(error))
.map_pass_err(scope)?;
}

let validation_pipeline = hub
.compute_pipelines
.read()
.get_owned(validation_pipeline_id)
.map_err(|_| {
ComputePassErrorInner::InvalidPipeline(validation_pipeline_id)
})
.map_pass_err(scope)?;

let bind_group = hub
.bind_groups
.read()
.get_owned(bind_group_id)
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(0))
.map_pass_err(scope)?;

let dst_buffer = hub
.buffers
.read()
.get_owned(dst_buffer_id)
.map_err(|_| ComputePassErrorInner::InvalidBuffer(dst_buffer_id))
.map_pass_err(scope)?;

new_commands.push(ArcComputeCommand::SetPipeline(validation_pipeline));
new_commands.push(ArcComputeCommand::SetBindGroup {
index: 0,
num_dynamic_offsets: 0,
bind_group,
});
new_commands.push(ArcComputeCommand::Dispatch([1, 1, 1]));

new_commands.push(ArcComputeCommand::SetPipeline(original_pipeline));
if let Some((num_dynamic_offsets, bind_group)) =
current_bind_group_0.clone()
{
new_commands.push(ArcComputeCommand::SetBindGroup {
index: 0,
num_dynamic_offsets,
bind_group,
});
} else {
new_commands.push(ArcComputeCommand::UnsetBindGroup { index: 0 });
}
new_commands.push(ArcComputeCommand::DispatchIndirect {
buffer: dst_buffer,
offset: 0,
});
} else {
new_commands.push(command.clone())
}
}
command => new_commands.push(command.clone()),
}
}
Ok(new_commands)
}
}

// Recording a compute pass.
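In summary, the injection pass above rewrites the command stream roughly as follows (illustrative pseudocode in comment form; names abbreviated):

// Recorded by the user:
//   SetPipeline(user_pipeline)
//   SetBindGroup { index: 0, bind_group: user_bg }
//   DispatchIndirect { buffer: user_buf, offset }
//
// Replayed after injection:
//   SetPipeline(user_pipeline)
//   SetBindGroup { index: 0, bind_group: user_bg }
//   SetPipeline(validation_pipeline)
//   SetBindGroup { index: 0, bind_group: validation_bg } // src: user_buf[offset..offset + 12], dst: fresh buffer
//   Dispatch([1, 1, 1])                                   // run the validation shader
//   SetPipeline(user_pipeline)                            // restore the user's pipeline
//   SetBindGroup { index: 0, bind_group: user_bg }        // or UnsetBindGroup { index: 0 } if none was set
//   DispatchIndirect { buffer: dst_buffer, offset: 0 }    // dispatch from the validated arguments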