Skip to content

Commit 5af9e30

Browse files
authoredFeb 12, 2025··
Fix HLSL single scalar loads (#7104)
1 parent 0dd6a1c commit 5af9e30

File tree

8 files changed

+268
-12
lines changed

8 files changed

+268
-12
lines changed
 

‎CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
104104

105105
#### Dx12
106106

107-
- Fix HLSL storage format generation. By @Vecvec in [#6993](https://github.com/gfx-rs/wgpu/pull/6993)
107+
- Fix HLSL storage format generation. By @Vecvec in [#6993](https://github.com/gfx-rs/wgpu/pull/6993) and [#7104](https://github.com/gfx-rs/wgpu/pull/7104)
108108
- Fix 3D storage texture bindings. By @SparkyPotato in [#7071](https://github.com/gfx-rs/wgpu/pull/7071)
109109

110110
#### WebGPU

‎naga/src/back/hlsl/help.rs

+101-5
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use super::{
3131
writer::{EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION},
3232
BackendResult,
3333
};
34-
use crate::{arena::Handle, proc::NameKey};
34+
use crate::{arena::Handle, proc::NameKey, ScalarKind};
3535
use std::fmt::Write;
3636

3737
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
@@ -128,6 +128,8 @@ impl From<crate::ImageQuery> for ImageQuery {
128128
}
129129
}
130130

131+
pub(super) const IMAGE_STORAGE_LOAD_SCALAR_WRAPPER: &str = "LoadedStorageValueFrom";
132+
131133
impl<W: Write> super::Writer<'_, W> {
132134
pub(super) fn write_image_type(
133135
&mut self,
@@ -513,6 +515,60 @@ impl<W: Write> super::Writer<'_, W> {
513515
Ok(())
514516
}
515517

518+
/// Writes the conversion from a single length storage texture load to a vec4 with the loaded
519+
/// scalar in its `x` component, 1 in its `a` component and 0 everywhere else.
520+
fn write_loaded_scalar_to_storage_loaded_value(
521+
&mut self,
522+
scalar_type: crate::Scalar,
523+
) -> BackendResult {
524+
const ARGUMENT_VARIABLE_NAME: &str = "arg";
525+
const RETURN_VARIABLE_NAME: &str = "ret";
526+
527+
let zero;
528+
let one;
529+
match scalar_type.kind {
530+
ScalarKind::Sint => {
531+
assert_eq!(
532+
scalar_type.width, 4,
533+
"Scalar {scalar_type:?} is not a result from any storage format"
534+
);
535+
zero = "0";
536+
one = "1";
537+
}
538+
ScalarKind::Uint => match scalar_type.width {
539+
4 => {
540+
zero = "0u";
541+
one = "1u";
542+
}
543+
8 => {
544+
zero = "0uL";
545+
one = "1uL"
546+
}
547+
_ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
548+
},
549+
ScalarKind::Float => {
550+
assert_eq!(
551+
scalar_type.width, 4,
552+
"Scalar {scalar_type:?} is not a result from any storage format"
553+
);
554+
zero = "0.0";
555+
one = "1.0";
556+
}
557+
_ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
558+
}
559+
560+
let ty = scalar_type.to_hlsl_str()?;
561+
writeln!(
562+
self.out,
563+
"{ty}4 {IMAGE_STORAGE_LOAD_SCALAR_WRAPPER}{ty}({ty} {ARGUMENT_VARIABLE_NAME}) {{\
564+
{ty}4 {RETURN_VARIABLE_NAME} = {ty}4({ARGUMENT_VARIABLE_NAME}, {zero}, {zero}, {one});\
565+
return {RETURN_VARIABLE_NAME};\
566+
}}"
567+
)?;
568+
569+
Ok(())
570+
}
571+
516572
pub(super) fn write_wrapped_struct_matrix_get_function_name(
517573
&mut self,
518574
access: WrappedStructMatrixAccess,
@@ -848,11 +904,12 @@ impl<W: Write> super::Writer<'_, W> {
848904
Ok(())
849905
}
850906

851-
/// Helper function that writes compose wrapped functions
852-
pub(super) fn write_wrapped_compose_functions(
907+
/// Helper function that writes wrapped functions for expressions in a function
908+
pub(super) fn write_wrapped_expression_functions(
853909
&mut self,
854910
module: &crate::Module,
855911
expressions: &crate::Arena<crate::Expression>,
912+
context: Option<&FunctionCtx>,
856913
) -> BackendResult {
857914
for (handle, _) in expressions.iter() {
858915
match expressions[handle] {
@@ -867,6 +924,23 @@ impl<W: Write> super::Writer<'_, W> {
867924
_ => {}
868925
};
869926
}
927+
crate::Expression::ImageLoad { image, .. } => {
928+
// This can only happen in a function as this is not a valid const expression
929+
match *context.as_ref().unwrap().resolve_type(image, &module.types) {
930+
crate::TypeInner::Image {
931+
class: crate::ImageClass::Storage { format, .. },
932+
..
933+
} => {
934+
if format.single_component() {
935+
let scalar: crate::Scalar = format.into();
936+
if self.wrapped.image_load_scalars.insert(scalar) {
937+
self.write_loaded_scalar_to_storage_loaded_value(scalar)?;
938+
}
939+
}
940+
}
941+
_ => {}
942+
}
943+
}
870944
crate::Expression::RayQueryGetIntersection { committed, .. } => {
871945
if committed {
872946
if !self.written_committed_intersection {
@@ -884,7 +958,7 @@ impl<W: Write> super::Writer<'_, W> {
884958
Ok(())
885959
}
886960

887-
// TODO: we could merge this with iteration in write_wrapped_compose_functions...
961+
// TODO: we could merge this with iteration in write_wrapped_expression_functions...
888962
//
889963
/// Helper function that writes zero value wrapped functions
890964
pub(super) fn write_wrapped_zero_value_functions(
@@ -1046,7 +1120,7 @@ impl<W: Write> super::Writer<'_, W> {
10461120
func_ctx: &FunctionCtx,
10471121
) -> BackendResult {
10481122
self.write_wrapped_math_functions(module, func_ctx)?;
1049-
self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
1123+
self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
10501124
self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
10511125

10521126
for (handle, _) in func_ctx.expressions.iter() {
@@ -1476,3 +1550,25 @@ impl<W: Write> super::Writer<'_, W> {
14761550
Ok(())
14771551
}
14781552
}
1553+
1554+
impl crate::StorageFormat {
1555+
/// Returns `true` if there is just one component, otherwise `false`
1556+
pub(super) const fn single_component(&self) -> bool {
1557+
match *self {
1558+
crate::StorageFormat::R16Float
1559+
| crate::StorageFormat::R32Float
1560+
| crate::StorageFormat::R8Unorm
1561+
| crate::StorageFormat::R16Unorm
1562+
| crate::StorageFormat::R8Snorm
1563+
| crate::StorageFormat::R16Snorm
1564+
| crate::StorageFormat::R8Uint
1565+
| crate::StorageFormat::R16Uint
1566+
| crate::StorageFormat::R32Uint
1567+
| crate::StorageFormat::R8Sint
1568+
| crate::StorageFormat::R16Sint
1569+
| crate::StorageFormat::R32Sint
1570+
| crate::StorageFormat::R64Uint => true,
1571+
_ => false,
1572+
}
1573+
}
1574+
}

‎naga/src/back/hlsl/keywords.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -908,4 +908,7 @@ pub const TYPES: &[&str] = &{
908908
res
909909
};
910910

911-
pub const RESERVED_PREFIXES: &[&str] = &["__dynamic_buffer_offsets"];
911+
pub const RESERVED_PREFIXES: &[&str] = &[
912+
"__dynamic_buffer_offsets",
913+
super::help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
914+
];

‎naga/src/back/hlsl/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ struct Wrapped {
360360
zero_values: crate::FastHashSet<help::WrappedZeroValue>,
361361
array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
362362
image_queries: crate::FastHashSet<help::WrappedImageQuery>,
363+
image_load_scalars: crate::FastHashSet<crate::Scalar>,
363364
constructors: crate::FastHashSet<help::WrappedConstructor>,
364365
struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
365366
mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,

‎naga/src/back/hlsl/writer.rs

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::{
2+
help,
23
help::{
34
WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
45
WrappedZeroValue,
@@ -341,7 +342,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
341342

342343
self.write_special_functions(module)?;
343344

344-
self.write_wrapped_compose_functions(module, &module.global_expressions)?;
345+
self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
345346
self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
346347

347348
// Write all named constants
@@ -3152,6 +3153,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
31523153
sample,
31533154
level,
31543155
} => {
3156+
let mut wrapping_type = None;
3157+
match *func_ctx.resolve_type(image, &module.types) {
3158+
TypeInner::Image {
3159+
class: crate::ImageClass::Storage { format, .. },
3160+
..
3161+
} => {
3162+
if format.single_component() {
3163+
wrapping_type = Some(Scalar::from(format));
3164+
}
3165+
}
3166+
_ => {}
3167+
}
3168+
if let Some(scalar) = wrapping_type {
3169+
write!(
3170+
self.out,
3171+
"{}{}(",
3172+
help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
3173+
scalar.to_hlsl_str()?
3174+
)?;
3175+
}
31553176
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
31563177
self.write_expr(module, image, func_ctx)?;
31573178
write!(self.out, ".Load(")?;
@@ -3173,6 +3194,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
31733194
// close bracket for Load function
31743195
write!(self.out, ")")?;
31753196

3197+
if wrapping_type.is_some() {
3198+
write!(self.out, ")")?;
3199+
}
3200+
31763201
// return x component if return type is scalar
31773202
if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
31783203
write!(self.out, ".x")?;

‎naga/tests/out/hlsl/storage-textures.hlsl

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ RWTexture2D<float> s_r_w : register(u0, space1);
55
RWTexture2D<float4> s_rg_w : register(u1, space1);
66
RWTexture2D<float4> s_rgba_w : register(u2, space1);
77

8+
float4 LoadedStorageValueFromfloat(float arg) {float4 ret = float4(arg, 0.0, 0.0, 1.0);return ret;}
89
[numthreads(1, 1, 1)]
910
void csLoad()
1011
{
11-
float4 phony = s_r_r.Load((0u).xx);
12+
float4 phony = LoadedStorageValueFromfloat(s_r_r.Load((0u).xx));
1213
float4 phony_1 = s_rg_r.Load((0u).xx);
1314
float4 phony_2 = s_rgba_r.Load((0u).xx);
1415
return;

‎tests/tests/texture_binding/mod.rs

+125-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
use std::time::Duration;
2+
use wgpu::wgt::BufferDescriptor;
13
use wgpu::{
2-
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, ComputePassDescriptor,
3-
ComputePipelineDescriptor, DownlevelFlags, Extent3d, Features, TextureDescriptor,
4-
TextureDimension, TextureFormat, TextureUsages,
4+
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferUsages,
5+
ComputePassDescriptor, ComputePipelineDescriptor, DownlevelFlags, Extent3d, Features, Maintain,
6+
MapMode, Origin3d, TexelCopyBufferInfo, TexelCopyBufferLayout, TexelCopyTextureInfo,
7+
TextureAspect, TextureDescriptor, TextureDimension, TextureFormat, TextureUsages,
58
};
69
use wgpu_macros::gpu_test;
710
use wgpu_test::{GpuTestConfiguration, TestParameters, TestingContext};
@@ -62,3 +65,122 @@ fn texture_binding(ctx: TestingContext) {
6265
}
6366
ctx.queue.submit([encoder.finish()]);
6467
}
68+
69+
#[gpu_test]
70+
static SINGLE_SCALAR_LOAD: GpuTestConfiguration = GpuTestConfiguration::new()
71+
.parameters(
72+
TestParameters::default()
73+
.test_features_limits()
74+
.downlevel_flags(DownlevelFlags::WEBGPU_TEXTURE_FORMAT_SUPPORT)
75+
.features(Features::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES),
76+
)
77+
.run_sync(single_scalar_load);
78+
79+
fn single_scalar_load(ctx: TestingContext) {
80+
let texture_read = ctx.device.create_texture(&TextureDescriptor {
81+
label: None,
82+
size: Extent3d {
83+
width: 1,
84+
height: 1,
85+
depth_or_array_layers: 1,
86+
},
87+
mip_level_count: 1,
88+
sample_count: 1,
89+
dimension: TextureDimension::D2,
90+
format: TextureFormat::R32Float,
91+
usage: TextureUsages::STORAGE_BINDING,
92+
view_formats: &[],
93+
});
94+
let texture_write = ctx.device.create_texture(&TextureDescriptor {
95+
label: None,
96+
size: Extent3d {
97+
width: 1,
98+
height: 1,
99+
depth_or_array_layers: 1,
100+
},
101+
mip_level_count: 1,
102+
sample_count: 1,
103+
dimension: TextureDimension::D2,
104+
format: TextureFormat::Rgba32Float,
105+
usage: TextureUsages::STORAGE_BINDING | TextureUsages::COPY_SRC,
106+
view_formats: &[],
107+
});
108+
let buffer = ctx.device.create_buffer(&BufferDescriptor {
109+
label: None,
110+
size: size_of::<[f32; 4]>() as wgpu::BufferAddress,
111+
usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
112+
mapped_at_creation: false,
113+
});
114+
let shader = ctx
115+
.device
116+
.create_shader_module(include_wgsl!("single_scalar.wgsl"));
117+
let pipeline = ctx
118+
.device
119+
.create_compute_pipeline(&ComputePipelineDescriptor {
120+
label: None,
121+
layout: None,
122+
module: &shader,
123+
entry_point: None,
124+
compilation_options: Default::default(),
125+
cache: None,
126+
});
127+
let bind = ctx.device.create_bind_group(&BindGroupDescriptor {
128+
label: None,
129+
layout: &pipeline.get_bind_group_layout(0),
130+
entries: &[
131+
BindGroupEntry {
132+
binding: 0,
133+
resource: BindingResource::TextureView(
134+
&texture_write.create_view(&Default::default()),
135+
),
136+
},
137+
BindGroupEntry {
138+
binding: 1,
139+
resource: BindingResource::TextureView(
140+
&texture_read.create_view(&Default::default()),
141+
),
142+
},
143+
],
144+
});
145+
146+
let mut encoder = ctx.device.create_command_encoder(&Default::default());
147+
{
148+
let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
149+
pass.set_pipeline(&pipeline);
150+
pass.set_bind_group(0, &bind, &[]);
151+
pass.dispatch_workgroups(1, 1, 1);
152+
}
153+
encoder.copy_texture_to_buffer(
154+
TexelCopyTextureInfo {
155+
texture: &texture_write,
156+
mip_level: 0,
157+
origin: Origin3d::ZERO,
158+
aspect: TextureAspect::All,
159+
},
160+
TexelCopyBufferInfo {
161+
buffer: &buffer,
162+
layout: TexelCopyBufferLayout {
163+
offset: 0,
164+
bytes_per_row: None,
165+
rows_per_image: None,
166+
},
167+
},
168+
Extent3d {
169+
width: 1,
170+
height: 1,
171+
depth_or_array_layers: 1,
172+
},
173+
);
174+
ctx.queue.submit([encoder.finish()]);
175+
let (send, recv) = std::sync::mpsc::channel();
176+
buffer.slice(..).map_async(MapMode::Read, move |res| {
177+
res.unwrap();
178+
send.send(()).expect("Thread should wait for receive");
179+
});
180+
// Poll to run map.
181+
ctx.device.poll(Maintain::Wait);
182+
recv.recv_timeout(Duration::from_secs(10))
183+
.expect("mapping should not take this long");
184+
let val = *bytemuck::from_bytes::<[f32; 4]>(&buffer.slice(..).get_mapped_range());
185+
assert_eq!(val, [0.0, 0.0, 0.0, 1.0]);
186+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@group(0) @binding(0)
2+
var tex_w: texture_storage_2d<rgba32float, write>;
3+
@group(0) @binding(1)
4+
var tex_r: texture_storage_2d<r32float, read>;
5+
6+
@compute @workgroup_size(1) fn csStore() {
7+
textureStore(tex_w, vec2u(0), textureLoad(tex_r, vec2u(0)));
8+
}

0 commit comments

Comments
 (0)
Please sign in to comment.