Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

[msl-out] Fix packed vec3 stores #1816

Merged
merged 5 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,6 @@ impl<'a> ExpressionContext<'a> {
index::access_needs_check(base, index, self.module, self.function, self.info)
}

// Because packed vectors such as `packed_float3` cannot be directly loaded,
// we convert them to unpacked vectors like `float3` on load.
fn get_packed_vec_kind(
&self,
expr_handle: Handle<crate::Expression>,
Expand Down Expand Up @@ -1917,16 +1915,14 @@ impl<W: Write> Writer<W> {
write!(self.out, ".{}", name)?;
}
crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(base);
// Note: this doesn't work for the left-hand side
if let Some(scalar_kind) = wrap_packed_vec_scalar_kind {
write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
self.put_access_chain(base, policy, context)?;
write!(self.out, ")")?;
self.put_access_chain(base, policy, context)?;
// Prior to Metal v2.1, component access for packed vectors wasn't available;
// however, array indexing is.
if context.get_packed_vec_kind(base).is_some() {
write!(self.out, "[{}]", index)?;
} else {
self.put_access_chain(base, policy, context)?;
write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
}
write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
}
_ => {
self.put_subscripted_access_chain(
Expand Down Expand Up @@ -2052,7 +2048,6 @@ impl<W: Write> Writer<W> {
policy: index::BoundsCheckPolicy,
context: &ExpressionContext,
) -> BackendResult {
let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(pointer);
let is_atomic = match *context.resolve_type(pointer) {
crate::TypeInner::Pointer { base, .. } => match context.module.types[base].inner {
crate::TypeInner::Atomic { .. } => true,
Expand All @@ -2061,11 +2056,7 @@ impl<W: Write> Writer<W> {
_ => false,
};

if let Some(scalar_kind) = wrap_packed_vec_scalar_kind {
write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ")")?;
} else if is_atomic {
if is_atomic {
write!(
self.out,
"{}::atomic_load_explicit({}",
Expand Down
12 changes: 12 additions & 0 deletions tests/in/globals.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,22 @@ var<uniform> float_vecs: array<vec4<f32>, 20>;
fn main() {
wg[3] = alignment.v1;
wg[2] = alignment.v3.x;
var _ = alignment.v3;
var _ = alignment.v3.zx;
alignment.v1 = 4.0;
wg[1] = f32(arrayLength(&dummy));
atomicStore(&at, 2u);

alignment.v3 = vec3<f32>(1.0);
var idx = 1;
alignment.v3.x = 1.0;
alignment.v3[0] = 2.0;
alignment.v3[idx] = 3.0;

let m = mat3x3<f32>();
let _ = alignment.v3 * m;
let _ = m * alignment.v3;

// Valid: `Foo` and `at` are in function scope
var Foo: f32 = 1.0;
var at: bool = true;
Expand Down
17 changes: 16 additions & 1 deletion tests/out/glsl/globals.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,30 @@ layout(std430) readonly buffer type_6_block_1Compute { vec2 _group_0_binding_2_c


void main() {
vec3 unnamed = vec3(0.0);
vec2 unnamed_1 = vec2(0.0);
int idx = 1;
float Foo_1 = 1.0;
bool at = true;
float _e9 = _group_0_binding_1_cs.v1_;
wg[3] = _e9;
float _e14 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e14;
vec3 _e16 = _group_0_binding_1_cs.v3_;
unnamed = _e16;
vec3 _e19 = _group_0_binding_1_cs.v3_;
unnamed_1 = _e19.zx;
_group_0_binding_1_cs.v1_ = 4.0;
wg[1] = float(uint(_group_0_binding_2_cs.length()));
at_1 = 2u;
return;
_group_0_binding_1_cs.v3_ = vec3(1.0);
_group_0_binding_1_cs.v3_.x = 1.0;
_group_0_binding_1_cs.v3_.x = 2.0;
int _e42 = idx;
_group_0_binding_1_cs.v3_[_e42] = 3.0;
vec3 _e47 = _group_0_binding_1_cs.v3_;
vec3 unnamed_2 = (_e47 * mat3x3(vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0)));
vec3 _e50 = _group_0_binding_1_cs.v3_;
vec3 unnamed_3 = (mat3x3(vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0)) * _e50);
}

17 changes: 16 additions & 1 deletion tests/out/hlsl/globals.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,30 @@ uint NagaBufferLength(ByteAddressBuffer buffer)
[numthreads(1, 1, 1)]
void main()
{
float3 unnamed = (float3)0;
float2 unnamed_1 = (float2)0;
int idx = 1;
float Foo_1 = 1.0;
bool at = true;

float _expr9 = asfloat(alignment.Load(12));
wg[3] = _expr9;
float _expr14 = asfloat(alignment.Load(0+0));
wg[2] = _expr14;
float3 _expr16 = asfloat(alignment.Load3(0));
unnamed = _expr16;
float3 _expr19 = asfloat(alignment.Load3(0));
unnamed_1 = _expr19.zx;
alignment.Store(12, asuint(4.0));
wg[1] = float(((NagaBufferLength(dummy) - 0) / 8));
at_1 = 2u;
return;
alignment.Store3(0, asuint(float3(1.0.xxx)));
alignment.Store(0+0, asuint(1.0));
alignment.Store(0+0, asuint(2.0));
int _expr42 = idx;
alignment.Store(_expr42*4+0, asuint(3.0));
float3 _expr47 = asfloat(alignment.Load3(0));
float3 unnamed_2 = mul(float3x3(float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0)), _expr47);
float3 _expr50 = asfloat(alignment.Load3(0));
float3 unnamed_3 = mul(_expr50, float3x3(float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0)));
}
21 changes: 19 additions & 2 deletions tests/out/msl/globals.msl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ typedef metal::float2 type_6[1];
struct type_8 {
metal::float4 inner[20];
};
constant metal::float3 const_type_4_ = {0.0, 0.0, 0.0};
constant metal::float3x3 const_type_10_ = {const_type_4_, const_type_4_, const_type_4_};

kernel void main_(
threadgroup type_2& wg
Expand All @@ -28,14 +30,29 @@ kernel void main_(
, device type_6 const& dummy [[user(fake0)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
metal::float3 unnamed;
metal::float2 unnamed_1;
int idx = 1;
float Foo_1 = 1.0;
bool at = true;
float _e9 = alignment.v1_;
wg.inner[3] = _e9;
float _e14 = metal::float3(alignment.v3_).x;
float _e14 = alignment.v3_[0];
wg.inner[2] = _e14;
metal::float3 _e16 = alignment.v3_;
unnamed = _e16;
metal::float3 _e19 = alignment.v3_;
unnamed_1 = _e19.zx;
alignment.v1_ = 4.0;
wg.inner[1] = static_cast<float>(1 + (_buffer_sizes.size3 - 0 - 8) / 8);
metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed);
return;
alignment.v3_ = metal::float3(1.0);
alignment.v3_[0] = 1.0;
alignment.v3_[0] = 2.0;
int _e42 = idx;
alignment.v3_[_e42] = 3.0;
metal::float3 _e47 = alignment.v3_;
metal::float3 unnamed_2 = _e47 * const_type_10_;
metal::float3 _e50 = alignment.v3_;
metal::float3 unnamed_3 = const_type_10_ * _e50;
}
194 changes: 115 additions & 79 deletions tests/out/spv/globals.spvasm
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 68
; Bound: 98
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %42 "main"
OpExecutionMode %42 LocalSize 1 1 1
OpDecorate %17 ArrayStride 4
OpMemberDecorate %19 0 Offset 0
OpMemberDecorate %19 1 Offset 12
OpDecorate %21 ArrayStride 8
OpDecorate %23 ArrayStride 16
OpDecorate %28 DescriptorSet 0
OpDecorate %28 Binding 1
OpDecorate %29 Block
OpMemberDecorate %29 0 Offset 0
OpDecorate %31 NonWritable
OpDecorate %31 DescriptorSet 0
OpDecorate %31 Binding 2
OpDecorate %32 Block
OpMemberDecorate %32 0 Offset 0
OpDecorate %34 DescriptorSet 0
OpDecorate %34 Binding 3
OpDecorate %35 Block
OpMemberDecorate %35 0 Offset 0
OpEntryPoint GLCompute %55 "main"
OpExecutionMode %55 LocalSize 1 1 1
OpDecorate %21 ArrayStride 4
OpMemberDecorate %23 0 Offset 0
OpMemberDecorate %23 1 Offset 12
OpDecorate %25 ArrayStride 8
OpDecorate %27 ArrayStride 16
OpDecorate %35 DescriptorSet 0
OpDecorate %35 Binding 1
OpDecorate %36 Block
OpMemberDecorate %36 0 Offset 0
OpDecorate %38 NonWritable
OpDecorate %38 DescriptorSet 0
OpDecorate %38 Binding 2
OpDecorate %39 Block
OpMemberDecorate %39 0 Offset 0
OpDecorate %41 DescriptorSet 0
OpDecorate %41 Binding 3
OpDecorate %42 Block
OpMemberDecorate %42 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeBool
%3 = OpConstantTrue %4
Expand All @@ -40,63 +40,99 @@ OpMemberDecorate %35 0 Offset 0
%13 = OpConstant %8 1
%14 = OpConstant %6 2
%15 = OpConstant %12 1.0
%16 = OpConstantTrue %4
%17 = OpTypeArray %12 %5
%18 = OpTypeVector %12 3
%19 = OpTypeStruct %18 %12
%20 = OpTypeVector %12 2
%21 = OpTypeRuntimeArray %20
%22 = OpTypeVector %12 4
%23 = OpTypeArray %22 %7
%25 = OpTypePointer Workgroup %17
%24 = OpVariable %25 Workgroup
%27 = OpTypePointer Workgroup %6
%26 = OpVariable %27 Workgroup
%29 = OpTypeStruct %19
%30 = OpTypePointer StorageBuffer %29
%28 = OpVariable %30 StorageBuffer
%32 = OpTypeStruct %21
%33 = OpTypePointer StorageBuffer %32
%31 = OpVariable %33 StorageBuffer
%35 = OpTypeStruct %23
%36 = OpTypePointer Uniform %35
%34 = OpVariable %36 Uniform
%38 = OpTypePointer Function %12
%40 = OpTypePointer Function %4
%43 = OpTypeFunction %2
%44 = OpTypePointer StorageBuffer %19
%45 = OpConstant %6 0
%47 = OpTypePointer StorageBuffer %21
%49 = OpTypePointer Uniform %23
%51 = OpTypePointer Workgroup %12
%52 = OpTypePointer StorageBuffer %12
%53 = OpConstant %6 1
%56 = OpConstant %6 3
%58 = OpTypePointer StorageBuffer %18
%59 = OpTypePointer StorageBuffer %12
%67 = OpConstant %6 256
%42 = OpFunction %2 None %43
%41 = OpLabel
%37 = OpVariable %38 Function %15
%39 = OpVariable %40 Function %16
%46 = OpAccessChain %44 %28 %45
%48 = OpAccessChain %47 %31 %45
OpBranch %50
%50 = OpLabel
%54 = OpAccessChain %52 %46 %53
%55 = OpLoad %12 %54
%57 = OpAccessChain %51 %24 %56
OpStore %57 %55
%60 = OpAccessChain %59 %46 %45 %45
%61 = OpLoad %12 %60
%62 = OpAccessChain %51 %24 %14
OpStore %62 %61
%63 = OpAccessChain %52 %46 %53
OpStore %63 %11
%64 = OpArrayLength %6 %31 0
%65 = OpConvertUToF %12 %64
%66 = OpAccessChain %51 %24 %53
OpStore %66 %65
OpAtomicStore %26 %10 %67 %14
%16 = OpConstant %8 0
%17 = OpConstant %12 2.0
%18 = OpConstant %12 3.0
%19 = OpConstant %12 0.0
%20 = OpConstantTrue %4
%21 = OpTypeArray %12 %5
%22 = OpTypeVector %12 3
%23 = OpTypeStruct %22 %12
%24 = OpTypeVector %12 2
%25 = OpTypeRuntimeArray %24
%26 = OpTypeVector %12 4
%27 = OpTypeArray %26 %7
%28 = OpTypeMatrix %22 3
%29 = OpConstantComposite %22 %19 %19 %19
%30 = OpConstantComposite %28 %29 %29 %29
%32 = OpTypePointer Workgroup %21
%31 = OpVariable %32 Workgroup
%34 = OpTypePointer Workgroup %6
%33 = OpVariable %34 Workgroup
%36 = OpTypeStruct %23
%37 = OpTypePointer StorageBuffer %36
%35 = OpVariable %37 StorageBuffer
%39 = OpTypeStruct %25
%40 = OpTypePointer StorageBuffer %39
%38 = OpVariable %40 StorageBuffer
%42 = OpTypeStruct %27
%43 = OpTypePointer Uniform %42
%41 = OpVariable %43 Uniform
%45 = OpTypePointer Function %22
%47 = OpTypePointer Function %24
%49 = OpTypePointer Function %8
%51 = OpTypePointer Function %12
%53 = OpTypePointer Function %4
%56 = OpTypeFunction %2
%57 = OpTypePointer StorageBuffer %23
%58 = OpConstant %6 0
%60 = OpTypePointer StorageBuffer %25
%62 = OpTypePointer Uniform %27
%64 = OpTypePointer Workgroup %12
%65 = OpTypePointer StorageBuffer %12
%66 = OpConstant %6 1
%69 = OpConstant %6 3
%71 = OpTypePointer StorageBuffer %22
%72 = OpTypePointer StorageBuffer %12
%85 = OpConstant %6 256
%55 = OpFunction %2 None %56
%54 = OpLabel
%52 = OpVariable %53 Function %20
%46 = OpVariable %47 Function
%50 = OpVariable %51 Function %15
%44 = OpVariable %45 Function
%48 = OpVariable %49 Function %13
%59 = OpAccessChain %57 %35 %58
%61 = OpAccessChain %60 %38 %58
OpBranch %63
%63 = OpLabel
%67 = OpAccessChain %65 %59 %66
%68 = OpLoad %12 %67
%70 = OpAccessChain %64 %31 %69
OpStore %70 %68
%73 = OpAccessChain %72 %59 %58 %58
%74 = OpLoad %12 %73
%75 = OpAccessChain %64 %31 %14
OpStore %75 %74
%76 = OpAccessChain %71 %59 %58
%77 = OpLoad %22 %76
OpStore %44 %77
%78 = OpAccessChain %71 %59 %58
%79 = OpLoad %22 %78
%80 = OpVectorShuffle %24 %79 %79 2 0
OpStore %46 %80
%81 = OpAccessChain %65 %59 %66
OpStore %81 %11
%82 = OpArrayLength %6 %38 0
%83 = OpConvertUToF %12 %82
%84 = OpAccessChain %64 %31 %66
OpStore %84 %83
OpAtomicStore %33 %10 %85 %14
%86 = OpCompositeConstruct %22 %15 %15 %15
%87 = OpAccessChain %71 %59 %58
OpStore %87 %86
%88 = OpAccessChain %72 %59 %58 %58
OpStore %88 %15
%89 = OpAccessChain %72 %59 %58 %58
OpStore %89 %17
%90 = OpLoad %8 %48
%91 = OpAccessChain %72 %59 %58 %90
OpStore %91 %18
%92 = OpAccessChain %71 %59 %58
%93 = OpLoad %22 %92
%94 = OpVectorTimesMatrix %22 %93 %30
%95 = OpAccessChain %71 %59 %58
%96 = OpLoad %22 %95
%97 = OpMatrixTimesVector %22 %30 %96
OpReturn
OpFunctionEnd
Loading