Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate that resources belong to the right device. #4207

Merged
merged 2 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ By @teoxoy in [#4185](https://github.com/gfx-rs/wgpu/pull/4185)
- Add stub support for device destroy and device validity. By @bradwerth in [4163](https://github.com/gfx-rs/wgpu/pull/4163)
- Add trace-level logging for most entry points in wgpu-core By @nical in [4183](https://github.com/gfx-rs/wgpu/pull/4183)
- Add `Rgb10a2Uint` format. By @teoxoy in [4199](https://github.com/gfx-rs/wgpu/pull/4199)
- Validate that resources are used on the right device. By @nical in [4207](https://github.com/gfx-rs/wgpu/pull/4207)

#### Vulkan

Expand Down
6 changes: 3 additions & 3 deletions deno_webgpu/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ impl From<DeviceError> for WebGpuError {
match err {
DeviceError::Lost => WebGpuError::Lost,
DeviceError::OutOfMemory => WebGpuError::OutOfMemory,
DeviceError::ResourceCreationFailed | DeviceError::Invalid => {
WebGpuError::Validation(fmt_err(&err))
}
DeviceError::ResourceCreationFailed
| DeviceError::Invalid
| DeviceError::WrongDevice => WebGpuError::Validation(fmt_err(&err)),
}
}
}
Expand Down
1 change: 0 additions & 1 deletion tests/tests/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ fn device_initialization() {
}

#[test]
#[ignore]
fn device_mismatch() {
initialize_test(
// https://github.com/gfx-rs/wgpu/issues/3927
Expand Down
45 changes: 36 additions & 9 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@ use crate::{
RenderCommand, RenderCommandError, StateChange,
},
device::{
AttachmentData, Device, MissingDownlevelFlags, MissingFeatures,
AttachmentData, Device, DeviceError, MissingDownlevelFlags, MissingFeatures,
RenderPassCompatibilityCheckType, RenderPassCompatibilityError, RenderPassContext,
},
error::{ErrorFormatter, PrettyError},
global::Global,
hal_api::HalApi,
hub::Token,
id,
id::DeviceId,
identity::GlobalIdentityHandlerFactory,
init_tracker::{MemoryInitKind, TextureInitRange, TextureInitTrackerAction},
pipeline::{self, PipelineFlags},
Expand Down Expand Up @@ -520,12 +519,12 @@ pub enum ColorAttachmentError {
/// Error encountered when performing a render pass.
#[derive(Clone, Debug, Error)]
pub enum RenderPassErrorInner {
#[error(transparent)]
Device(DeviceError),
#[error(transparent)]
ColorAttachment(#[from] ColorAttachmentError),
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Device {0:?} is invalid")]
InvalidDevice(DeviceId),
#[error("Attachment texture view {0:?} is invalid")]
InvalidAttachment(id::TextureViewId),
#[error("The format of the depth-stencil attachment ({0:?}) is not a depth-stencil format")]
Expand Down Expand Up @@ -658,6 +657,12 @@ impl From<MissingTextureUsageError> for RenderPassErrorInner {
}
}

impl From<DeviceError> for RenderPassErrorInner {
fn from(error: DeviceError) -> Self {
Self::Device(error)
}
}

/// Error encountered when performing a render pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
Expand Down Expand Up @@ -1351,12 +1356,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
});
}

let device = &device_guard[cmd_buf.device_id.value];
let device_id = cmd_buf.device_id.value;

let device = &device_guard[device_id];
if !device.is_valid() {
return Err(RenderPassErrorInner::InvalidDevice(
cmd_buf.device_id.value.0,
))
.map_pass_err(init_scope);
return Err(DeviceError::Invalid).map_pass_err(init_scope);
}
cmd_buf.encoder.open_pass(base.label);

Expand Down Expand Up @@ -1451,6 +1455,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.add_single(&*bind_group_guard, bind_group_id)
.ok_or(RenderCommandError::InvalidBindGroup(bind_group_id))
.map_pass_err(scope)?;

if bind_group.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

bind_group
.validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
.map_pass_err(scope)?;
Expand Down Expand Up @@ -1518,6 +1527,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.ok_or(RenderCommandError::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?;

if pipeline.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

info.context
.check_compatible(
&pipeline.pass_context,
Expand Down Expand Up @@ -1635,6 +1648,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.buffers
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDEX)
.map_pass_err(scope)?;

if buffer.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

check_buffer_usage(buffer.usage, BufferUsages::INDEX)
.map_pass_err(scope)?;
let buf_raw = buffer
Expand Down Expand Up @@ -1683,6 +1701,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.buffers
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::VERTEX)
.map_pass_err(scope)?;

if buffer.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

check_buffer_usage(buffer.usage, BufferUsages::VERTEX)
.map_pass_err(scope)?;
let buf_raw = buffer
Expand Down Expand Up @@ -2265,6 +2288,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.ok_or(RenderCommandError::InvalidRenderBundle(bundle_id))
.map_pass_err(scope)?;

if bundle.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

info.context
.check_compatible(
&bundle.context,
Expand Down
4 changes: 4 additions & 0 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
Err(..) => break binding_model::CreateBindGroupError::InvalidLayout,
};

if bind_group_layout.device_id.value.0 != device_id {
break DeviceError::WrongDevice.into();
}

let mut layout_id = id::Valid(desc.layout);
if let Some(id) = bind_group_layout.as_duplicate() {
layout_id = id;
Expand Down
2 changes: 2 additions & 0 deletions wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ pub enum DeviceError {
OutOfMemory,
#[error("Creation of a resource failed for a reason other than running out of memory.")]
ResourceCreationFailed,
#[error("Attempt to use a resource with a different device from the one that created it")]
WrongDevice,
}

impl From<hal::DeviceError> for DeviceError {
Expand Down
16 changes: 16 additions & 0 deletions wgpu-core/src/device/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

let result = self.queue_write_staging_buffer_impl(
queue_id,
device,
device_token,
&staging_buffer,
Expand Down Expand Up @@ -464,6 +465,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

let result = self.queue_write_staging_buffer_impl(
queue_id,
device,
device_token,
&staging_buffer,
Expand Down Expand Up @@ -531,6 +533,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

fn queue_write_staging_buffer_impl<A: HalApi>(
&self,
device_id: id::DeviceId,
device: &mut super::Device<A>,
device_token: &mut Token<super::Device<A>>,
staging_buffer: &StagingBuffer<A>,
Expand All @@ -551,6 +554,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.as_ref()
.ok_or(TransferError::InvalidBuffer(buffer_id))?;

if dst.device_id.value.0 != device_id {
return Err(DeviceError::WrongDevice.into());
}

let src_buffer_size = staging_buffer.size;
self.queue_validate_write_buffer_impl(dst, buffer_id, buffer_offset, src_buffer_size)?;

Expand Down Expand Up @@ -627,6 +634,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.get_mut(destination.texture)
.map_err(|_| TransferError::InvalidTexture(destination.texture))?;

if dst.device_id.value.0 != queue_id {
return Err(DeviceError::WrongDevice.into());
}

if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) {
return Err(
TransferError::MissingCopyDstUsageFlag(None, Some(destination.texture)).into(),
Expand Down Expand Up @@ -1105,6 +1116,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
Some(cmdbuf) => cmdbuf,
None => continue,
};

if cmdbuf.device_id.value.0 != queue_id {
return Err(DeviceError::WrongDevice.into());
}

#[cfg(feature = "trace")]
if let Some(ref trace) = device.trace {
trace.lock().add(Action::Submit(
Expand Down
47 changes: 47 additions & 0 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,7 @@ impl<A: HalApi> Device<A> {
}

fn create_buffer_binding<'a>(
device_id: id::DeviceId,
bb: &binding_model::BufferBinding,
binding: u32,
decl: &wgt::BindGroupLayoutEntry,
Expand Down Expand Up @@ -1727,6 +1728,11 @@ impl<A: HalApi> Device<A> {
.buffers
.add_single(storage, bb.buffer_id, internal_use)
.ok_or(Error::InvalidBuffer(bb.buffer_id))?;

if buffer.device_id.value.0 != device_id {
return Err(DeviceError::WrongDevice.into());
}

check_buffer_usage(buffer.usage, pub_usage)?;
let raw_buffer = buffer
.raw
Expand Down Expand Up @@ -1797,6 +1803,7 @@ impl<A: HalApi> Device<A> {
}

fn create_texture_binding(
device_id: id::DeviceId,
view: &resource::TextureView<A>,
texture_guard: &Storage<resource::Texture<A>, id::TextureId>,
internal_use: hal::TextureUses,
Expand All @@ -1818,6 +1825,11 @@ impl<A: HalApi> Device<A> {
.ok_or(binding_model::CreateBindGroupError::InvalidTexture(
view.parent_id.value.0,
))?;

if texture.device_id.value.0 != device_id {
return Err(DeviceError::WrongDevice.into());
}

check_texture_usage(texture.desc.usage, pub_usage)?;

used_texture_ranges.push(TextureInitTrackerAction {
Expand Down Expand Up @@ -1889,6 +1901,7 @@ impl<A: HalApi> Device<A> {
let (res_index, count) = match entry.resource {
Br::Buffer(ref bb) => {
let bb = Self::create_buffer_binding(
self_id,
bb,
binding,
decl,
Expand All @@ -1911,6 +1924,7 @@ impl<A: HalApi> Device<A> {
let res_index = hal_buffers.len();
for bb in bindings_array.iter() {
let bb = Self::create_buffer_binding(
self_id,
bb,
binding,
decl,
Expand All @@ -1933,6 +1947,10 @@ impl<A: HalApi> Device<A> {
.add_single(&*sampler_guard, id)
.ok_or(Error::InvalidSampler(id))?;

if sampler.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

// Allowed sampler values for filtering and comparison
let (allowed_filtering, allowed_comparison) = match ty {
wgt::SamplerBindingType::Filtering => (None, false),
Expand Down Expand Up @@ -1981,6 +1999,11 @@ impl<A: HalApi> Device<A> {
.samplers
.add_single(&*sampler_guard, id)
.ok_or(Error::InvalidSampler(id))?;

if sampler.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

hal_samplers.push(&sampler.raw);
}

Expand All @@ -1998,6 +2021,7 @@ impl<A: HalApi> Device<A> {
"SampledTexture, ReadonlyStorageTexture or WriteonlyStorageTexture",
)?;
Self::create_texture_binding(
self_id,
view,
&texture_guard,
internal_use,
Expand Down Expand Up @@ -2026,6 +2050,7 @@ impl<A: HalApi> Device<A> {
Self::texture_use_parameters(binding, decl, view,
"SampledTextureArray, ReadonlyStorageTextureArray or WriteonlyStorageTextureArray")?;
Self::create_texture_binding(
self_id,
view,
&texture_guard,
internal_use,
Expand Down Expand Up @@ -2324,6 +2349,11 @@ impl<A: HalApi> Device<A> {
let Some(bind_group_layout) = try_get_bind_group_layout(bgl_guard, id) else {
return Err(Error::InvalidBindGroupLayout(id));
};

if bind_group_layout.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

count_validator.merge(&bind_group_layout.assume_deduplicated().count_validator);
}
count_validator
Expand Down Expand Up @@ -2457,6 +2487,10 @@ impl<A: HalApi> Device<A> {
.get(desc.stage.module)
.map_err(|_| validation::StageError::InvalidModule)?;

if shader_module.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

{
let flag = wgt::ShaderStages::COMPUTE;
let provided_layouts = match desc.layout {
Expand Down Expand Up @@ -2500,6 +2534,10 @@ impl<A: HalApi> Device<A> {
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;

if layout.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

let late_sized_buffer_groups =
Device::make_late_sized_buffer_groups(&shader_binding_sizes, layout, &*bgl_guard);

Expand Down Expand Up @@ -2843,11 +2881,20 @@ impl<A: HalApi> Device<A> {
}
})?;

if shader_module.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

let provided_layouts = match desc.layout {
Some(pipeline_layout_id) => {
let pipeline_layout = pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?;

if pipeline_layout.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

Some(Device::get_introspection_bind_group_layouts(
pipeline_layout,
&*bgl_guard,
Expand Down