Skip to content

Fix HLSL single scalar loads #7104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -104,7 +104,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).

#### Dx12

- Fix HLSL storage format generation. By @Vecvec in [#6993](https://github.com/gfx-rs/wgpu/pull/6993)
- Fix HLSL storage format generation. By @Vecvec in [#6993](https://github.com/gfx-rs/wgpu/pull/6993) and [#7104](https://github.com/gfx-rs/wgpu/pull/7104)
- Fix 3D storage texture bindings. By @SparkyPotato in [#7071](https://github.com/gfx-rs/wgpu/pull/7071)

#### WebGPU
106 changes: 101 additions & 5 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ use super::{
writer::{EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION},
BackendResult,
};
use crate::{arena::Handle, proc::NameKey};
use crate::{arena::Handle, proc::NameKey, ScalarKind};
use std::fmt::Write;

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
@@ -128,6 +128,8 @@ impl From<crate::ImageQuery> for ImageQuery {
}
}

/// Name prefix of the generated HLSL helper that wraps single-component storage texture
/// loads; the HLSL name of the loaded scalar type is appended to form the function name.
pub(super) const IMAGE_STORAGE_LOAD_SCALAR_WRAPPER: &str = "LoadedStorageValueFrom";

impl<W: Write> super::Writer<'_, W> {
pub(super) fn write_image_type(
&mut self,
@@ -513,6 +515,60 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

/// Emits the HLSL helper that widens a single-component storage texture load into a
/// four-component vector.
///
/// The generated function is named [`IMAGE_STORAGE_LOAD_SCALAR_WRAPPER`] followed by the
/// HLSL name of `scalar_type`. It returns a vector holding the loaded scalar in `x`, zero
/// in `y` and `z`, and one in `w`.
///
/// # Panics
///
/// Panics unless `scalar_type` is a signed integer of width 4, an unsigned integer of
/// width 4 or 8, or a float of width 4.
fn write_loaded_scalar_to_storage_loaded_value(
    &mut self,
    scalar_type: crate::Scalar,
) -> BackendResult {
    const ARGUMENT_VARIABLE_NAME: &str = "arg";
    const RETURN_VARIABLE_NAME: &str = "ret";

    // HLSL literal spellings of zero and one for the scalar being wrapped.
    let (zero, one) = match scalar_type.kind {
        ScalarKind::Sint => {
            assert_eq!(
                scalar_type.width, 4,
                "Scalar {scalar_type:?} is not a result from any storage format"
            );
            ("0", "1")
        }
        ScalarKind::Uint => match scalar_type.width {
            4 => ("0u", "1u"),
            8 => ("0uL", "1uL"),
            _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
        },
        ScalarKind::Float => {
            assert_eq!(
                scalar_type.width, 4,
                "Scalar {scalar_type:?} is not a result from any storage format"
            );
            ("0.0", "1.0")
        }
        _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
    };

    let ty = scalar_type.to_hlsl_str()?;
    // The whole helper is written on a single output line; the trailing backslashes only
    // break the Rust source line and contribute no whitespace to the output.
    writeln!(
        self.out,
        "{ty}4 {IMAGE_STORAGE_LOAD_SCALAR_WRAPPER}{ty}({ty} {ARGUMENT_VARIABLE_NAME}) {{\
        {ty}4 {RETURN_VARIABLE_NAME} = {ty}4({ARGUMENT_VARIABLE_NAME}, {zero}, {zero}, {one});\
        return {RETURN_VARIABLE_NAME};\
        }}"
    )?;

    Ok(())
}

pub(super) fn write_wrapped_struct_matrix_get_function_name(
&mut self,
access: WrappedStructMatrixAccess,
@@ -848,11 +904,12 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

/// Helper function that writes compose wrapped functions
pub(super) fn write_wrapped_compose_functions(
/// Helper function that writes wrapped functions for expressions in a function
pub(super) fn write_wrapped_expression_functions(
&mut self,
module: &crate::Module,
expressions: &crate::Arena<crate::Expression>,
context: Option<&FunctionCtx>,
) -> BackendResult {
for (handle, _) in expressions.iter() {
match expressions[handle] {
@@ -867,6 +924,23 @@ impl<W: Write> super::Writer<'_, W> {
_ => {}
};
}
crate::Expression::ImageLoad { image, .. } => {
// This can only happen in a function as this is not a valid const expression
match *context.as_ref().unwrap().resolve_type(image, &module.types) {
crate::TypeInner::Image {
class: crate::ImageClass::Storage { format, .. },
..
} => {
if format.single_component() {
let scalar: crate::Scalar = format.into();
if self.wrapped.image_load_scalars.insert(scalar) {
self.write_loaded_scalar_to_storage_loaded_value(scalar)?;
}
}
}
_ => {}
}
}
crate::Expression::RayQueryGetIntersection { committed, .. } => {
if committed {
if !self.written_committed_intersection {
@@ -884,7 +958,7 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

// TODO: we could merge this with iteration in write_wrapped_compose_functions...
// TODO: we could merge this with iteration in write_wrapped_expression_functions...
//
/// Helper function that writes zero value wrapped functions
pub(super) fn write_wrapped_zero_value_functions(
@@ -1046,7 +1120,7 @@ impl<W: Write> super::Writer<'_, W> {
func_ctx: &FunctionCtx,
) -> BackendResult {
self.write_wrapped_math_functions(module, func_ctx)?;
self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;

for (handle, _) in func_ctx.expressions.iter() {
@@ -1476,3 +1550,25 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}
}

impl crate::StorageFormat {
    /// Reports whether this format has just one component.
    ///
    /// Returns `true` for the `R*` formats listed below and `false` for every other format.
    pub(super) const fn single_component(&self) -> bool {
        matches!(
            *self,
            // 8-bit
            Self::R8Unorm
                | Self::R8Snorm
                | Self::R8Uint
                | Self::R8Sint
                // 16-bit
                | Self::R16Unorm
                | Self::R16Snorm
                | Self::R16Uint
                | Self::R16Sint
                | Self::R16Float
                // 32-bit
                | Self::R32Uint
                | Self::R32Sint
                | Self::R32Float
                // 64-bit
                | Self::R64Uint
        )
    }
}
5 changes: 4 additions & 1 deletion naga/src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
@@ -908,4 +908,7 @@ pub const TYPES: &[&str] = &{
res
};

pub const RESERVED_PREFIXES: &[&str] = &["__dynamic_buffer_offsets"];
// Identifier prefixes reserved by the HLSL backend. The list holds the dynamic buffer
// offsets prefix and the prefix of the storage-load scalar wrapper helper from `help`.
// NOTE(review): presumably identifiers starting with one of these are renamed so they
// cannot collide with generated names — confirm against the consumer of this list.
pub const RESERVED_PREFIXES: &[&str] = &[
"__dynamic_buffer_offsets",
super::help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
];
1 change: 1 addition & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
@@ -360,6 +360,7 @@ struct Wrapped {
zero_values: crate::FastHashSet<help::WrappedZeroValue>,
array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
image_queries: crate::FastHashSet<help::WrappedImageQuery>,
image_load_scalars: crate::FastHashSet<crate::Scalar>,
constructors: crate::FastHashSet<help::WrappedConstructor>,
struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
27 changes: 26 additions & 1 deletion naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{
help,
help::{
WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
WrappedZeroValue,
@@ -341,7 +342,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

self.write_special_functions(module)?;

self.write_wrapped_compose_functions(module, &module.global_expressions)?;
self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

// Write all named constants
@@ -3152,6 +3153,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
sample,
level,
} => {
let mut wrapping_type = None;
match *func_ctx.resolve_type(image, &module.types) {
TypeInner::Image {
class: crate::ImageClass::Storage { format, .. },
..
} => {
if format.single_component() {
wrapping_type = Some(Scalar::from(format));
}
}
_ => {}
}
if let Some(scalar) = wrapping_type {
write!(
self.out,
"{}{}(",
help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
scalar.to_hlsl_str()?
)?;
}
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
self.write_expr(module, image, func_ctx)?;
write!(self.out, ".Load(")?;
@@ -3173,6 +3194,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// close bracket for Load function
write!(self.out, ")")?;

if wrapping_type.is_some() {
write!(self.out, ")")?;
}

// return x component if return type is scalar
if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
write!(self.out, ".x")?;
3 changes: 2 additions & 1 deletion naga/tests/out/hlsl/storage-textures.hlsl
Original file line number Diff line number Diff line change
@@ -5,10 +5,11 @@ RWTexture2D<float> s_r_w : register(u0, space1);
RWTexture2D<float4> s_rg_w : register(u1, space1);
RWTexture2D<float4> s_rgba_w : register(u2, space1);

float4 LoadedStorageValueFromfloat(float arg) {float4 ret = float4(arg, 0.0, 0.0, 1.0);return ret;}
[numthreads(1, 1, 1)]
void csLoad()
{
float4 phony = s_r_r.Load((0u).xx);
float4 phony = LoadedStorageValueFromfloat(s_r_r.Load((0u).xx));
float4 phony_1 = s_rg_r.Load((0u).xx);
float4 phony_2 = s_rgba_r.Load((0u).xx);
return;
128 changes: 125 additions & 3 deletions tests/tests/texture_binding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::time::Duration;
use wgpu::wgt::BufferDescriptor;
use wgpu::{
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, ComputePassDescriptor,
ComputePipelineDescriptor, DownlevelFlags, Extent3d, Features, TextureDescriptor,
TextureDimension, TextureFormat, TextureUsages,
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferUsages,
ComputePassDescriptor, ComputePipelineDescriptor, DownlevelFlags, Extent3d, Features, Maintain,
MapMode, Origin3d, TexelCopyBufferInfo, TexelCopyBufferLayout, TexelCopyTextureInfo,
TextureAspect, TextureDescriptor, TextureDimension, TextureFormat, TextureUsages,
};
use wgpu_macros::gpu_test;
use wgpu_test::{GpuTestConfiguration, TestParameters, TestingContext};
@@ -62,3 +65,122 @@ fn texture_binding(ctx: TestingContext) {
}
ctx.queue.submit([encoder.finish()]);
}

// GPU test configuration that runs `single_scalar_load` synchronously.
// It requests the `WEBGPU_TEXTURE_FORMAT_SUPPORT` downlevel flag and the
// `TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES` feature on top of the default parameters.
#[gpu_test]
static SINGLE_SCALAR_LOAD: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.test_features_limits()
.downlevel_flags(DownlevelFlags::WEBGPU_TEXTURE_FORMAT_SUPPORT)
.features(Features::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES),
)
.run_sync(single_scalar_load);

fn single_scalar_load(ctx: TestingContext) {
let texture_read = ctx.device.create_texture(&TextureDescriptor {
label: None,
size: Extent3d {
width: 1,
height: 1,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: TextureDimension::D2,
format: TextureFormat::R32Float,
usage: TextureUsages::STORAGE_BINDING,
view_formats: &[],
});
let texture_write = ctx.device.create_texture(&TextureDescriptor {
label: None,
size: Extent3d {
width: 1,
height: 1,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: TextureDimension::D2,
format: TextureFormat::Rgba32Float,
usage: TextureUsages::STORAGE_BINDING | TextureUsages::COPY_SRC,
view_formats: &[],
});
let buffer = ctx.device.create_buffer(&BufferDescriptor {
label: None,
size: size_of::<[f32; 4]>() as wgpu::BufferAddress,
usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let shader = ctx
.device
.create_shader_module(include_wgsl!("single_scalar.wgsl"));
let pipeline = ctx
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: None,
layout: None,
module: &shader,
entry_point: None,
compilation_options: Default::default(),
cache: None,
});
let bind = ctx.device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&texture_write.create_view(&Default::default()),
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&texture_read.create_view(&Default::default()),
),
},
],
});

let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &bind, &[]);
pass.dispatch_workgroups(1, 1, 1);
}
encoder.copy_texture_to_buffer(
TexelCopyTextureInfo {
texture: &texture_write,
mip_level: 0,
origin: Origin3d::ZERO,
aspect: TextureAspect::All,
},
TexelCopyBufferInfo {
buffer: &buffer,
layout: TexelCopyBufferLayout {
offset: 0,
bytes_per_row: None,
rows_per_image: None,
},
},
Extent3d {
width: 1,
height: 1,
depth_or_array_layers: 1,
},
);
ctx.queue.submit([encoder.finish()]);
let (send, recv) = std::sync::mpsc::channel();
buffer.slice(..).map_async(MapMode::Read, move |res| {
res.unwrap();
send.send(()).expect("Thread should wait for receive");
});
// Poll to run map.
ctx.device.poll(Maintain::Wait);
recv.recv_timeout(Duration::from_secs(10))
.expect("mapping should not take this long");
let val = *bytemuck::from_bytes::<[f32; 4]>(&buffer.slice(..).get_mapped_range());
assert_eq!(val, [0.0, 0.0, 0.0, 1.0]);
}
Loading