Skip to content

Commit

Permalink
Allow non-structure buffer types
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Feb 2, 2022
1 parent bb604fd commit a00638f
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 131 deletions.
30 changes: 19 additions & 11 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, " {}", name)?;
if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
self.write_array_size(module, size)?;
}

if let Some(ref binding) = global.binding {
// this was already resolved earlier when we started evaluating an entry point.
Expand All @@ -597,20 +594,31 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ", space{}", bt.space)?;
}
write!(self.out, ")")?;
} else if global.space == crate::AddressSpace::Private {
write!(self.out, " = ")?;
if let Some(init) = global.init {
self.write_constant(module, init)?;
} else {
self.write_default_init(module, global.ty)?;
} else {
// need to write the array size if the type was emitted with `write_type`
if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
self.write_array_size(module, size)?;
}
if global.space == crate::AddressSpace::Private {
write!(self.out, " = ")?;
if let Some(init) = global.init {
self.write_constant(module, init)?;
} else {
self.write_default_init(module, global.ty)?;
}
}
}

if global.space == crate::AddressSpace::Uniform {
write!(self.out, " {{ ")?;
self.write_type(module, global.ty)?;
let name = &self.names[&NameKey::GlobalVariable(handle)];
writeln!(self.out, " {}; }}", name)?;
let sub_name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, " {}", sub_name)?;
// need to write the array size if the type was emitted with `write_type`
if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
self.write_array_size(module, size)?;
}
writeln!(self.out, "; }}")?;
} else {
writeln!(self.out, ";")?;
}
Expand Down
41 changes: 25 additions & 16 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,18 +368,25 @@ fn should_pack_struct_member(
}

fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = arena[member.ty].inner
{
return true;
match arena[ty].inner {
crate::TypeInner::Struct { ref members, .. } => {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = arena[member.ty].inner
{
return true;
}
}
false
}
crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} => true,
_ => false,
}
false
}

impl crate::AddressSpace {
Expand Down Expand Up @@ -741,16 +748,18 @@ impl<W: Write> Writer<W> {
context: &ExpressionContext,
) -> BackendResult {
let global = &context.module.global_variables[handle];
let members = match context.module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => members,
let (offset, array_ty) = match context.module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => match members.last() {
Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
None => return Err(Error::Validation),
},
crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} => (0, global.ty),
_ => return Err(Error::Validation),
};

let (offset, array_ty) = match members.last() {
Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
None => return Err(Error::Validation),
};

let (size, stride) = match context.module.types[array_ty].inner {
crate::TypeInner::Array { base, stride, .. } => (
context.module.types[base]
Expand Down
3 changes: 2 additions & 1 deletion src/back/spv/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab
},
None => false,
},
_ => false,
// if it's not a structure, let's wrap it to be able to put "Block"
_ => true,
}
}
20 changes: 9 additions & 11 deletions src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,15 @@ impl super::Validator {
true,
)
}
crate::AddressSpace::Handle => (TypeFlags::empty(), true),
crate::AddressSpace::Handle => {
match types[var.ty].inner {
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {}
_ => {
return Err(GlobalVariableError::InvalidType);
}
};
(TypeFlags::empty(), true)
}
crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {
(TypeFlags::DATA | TypeFlags::SIZED, false)
}
Expand All @@ -375,16 +383,6 @@ impl super::Validator {
}
};

let is_handle = var.space == crate::AddressSpace::Handle;
let good_type = match types[var.ty].inner {
crate::TypeInner::Struct { .. } => !is_handle,
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => is_handle,
_ => false,
};
if is_resource && !good_type {
return Err(GlobalVariableError::InvalidType);
}

if !type_info.flags.contains(required_type_flags) {
return Err(GlobalVariableError::MissingTypeFlags {
seen: type_info.flags,
Expand Down
9 changes: 4 additions & 5 deletions tests/in/globals.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ struct Foo {
@group(0) @binding(1)
var<storage> alignment: Foo;

struct Dummy {
arr: array<vec2<f32>>;
};

@group(0) @binding(2)
var<storage> dummy: Dummy;
var<storage> dummy: array<vec2<f32>>;

@group(0) @binding(3)
var<uniform> float_vecs: array<vec4<f32>, 20>;

@stage(compute) @workgroup_size(1)
fn main() {
Expand Down
8 changes: 4 additions & 4 deletions tests/out/glsl/globals.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ layout(std430) readonly buffer Foo_block_0Compute { Foo _group_0_binding_1_cs; }
void main() {
float Foo_1 = 1.0;
bool at = true;
float _e8 = _group_0_binding_1_cs.v1_;
wg[3] = _e8;
float _e13 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e13;
float _e9 = _group_0_binding_1_cs.v1_;
wg[3] = _e9;
float _e14 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e14;
at_1 = 2u;
return;
}
Expand Down
9 changes: 5 additions & 4 deletions tests/out/hlsl/globals.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ groupshared float wg[10];
groupshared uint at_1;
ByteAddressBuffer alignment : register(t1);
ByteAddressBuffer dummy : register(t2);
cbuffer float_vecs : register(b3) { float4 float_vecs[20]; }

[numthreads(1, 1, 1)]
void main()
{
float Foo_1 = 1.0;
bool at = true;

float _expr8 = asfloat(alignment.Load(12));
wg[3] = _expr8;
float _expr13 = asfloat(alignment.Load(0+0));
wg[2] = _expr13;
float _expr9 = asfloat(alignment.Load(12));
wg[3] = _expr9;
float _expr14 = asfloat(alignment.Load(0+0));
wg[2] = _expr14;
at_1 = 2u;
return;
}
12 changes: 6 additions & 6 deletions tests/out/msl/globals.msl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ struct Foo {
float v1_;
};
typedef metal::float2 type_6[1];
struct Dummy {
type_6 arr;
struct type_8 {
metal::float4 inner[20];
};

kernel void main_(
Expand All @@ -26,10 +26,10 @@ kernel void main_(
) {
float Foo_1 = 1.0;
bool at = true;
float _e8 = alignment.v1_;
wg.inner[3] = _e8;
float _e13 = metal::float3(alignment.v3_).x;
wg.inner[2] = _e13;
float _e9 = alignment.v1_;
wg.inner[3] = _e9;
float _e14 = metal::float3(alignment.v3_).x;
wg.inner[2] = _e14;
metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed);
return;
}
141 changes: 77 additions & 64 deletions tests/out/spv/globals.spvasm
Original file line number Diff line number Diff line change
@@ -1,81 +1,94 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 53
; Bound: 61
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %34 "main"
OpExecutionMode %34 LocalSize 1 1 1
OpDecorate %14 ArrayStride 4
OpMemberDecorate %16 0 Offset 0
OpMemberDecorate %16 1 Offset 12
OpDecorate %18 ArrayStride 8
OpMemberDecorate %19 0 Offset 0
OpDecorate %24 NonWritable
OpDecorate %24 DescriptorSet 0
OpDecorate %24 Binding 1
OpDecorate %25 Block
OpMemberDecorate %25 0 Offset 0
OpDecorate %27 NonWritable
OpDecorate %27 DescriptorSet 0
OpDecorate %27 Binding 2
OpDecorate %19 Block
OpEntryPoint GLCompute %40 "main"
OpExecutionMode %40 LocalSize 1 1 1
OpDecorate %15 ArrayStride 4
OpMemberDecorate %17 0 Offset 0
OpMemberDecorate %17 1 Offset 12
OpDecorate %19 ArrayStride 8
OpDecorate %21 ArrayStride 16
OpDecorate %26 NonWritable
OpDecorate %26 DescriptorSet 0
OpDecorate %26 Binding 1
OpDecorate %27 Block
OpMemberDecorate %27 0 Offset 0
OpDecorate %29 NonWritable
OpDecorate %29 DescriptorSet 0
OpDecorate %29 Binding 2
OpDecorate %30 Block
OpMemberDecorate %30 0 Offset 0
OpDecorate %32 DescriptorSet 0
OpDecorate %32 Binding 3
OpDecorate %33 Block
OpMemberDecorate %33 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeBool
%3 = OpConstantTrue %4
%6 = OpTypeInt 32 0
%5 = OpConstant %6 10
%8 = OpTypeInt 32 1
%7 = OpConstant %8 3
%9 = OpConstant %8 2
%10 = OpConstant %6 2
%12 = OpTypeFloat 32
%11 = OpConstant %12 1.0
%13 = OpConstantTrue %4
%14 = OpTypeArray %12 %5
%15 = OpTypeVector %12 3
%16 = OpTypeStruct %15 %12
%17 = OpTypeVector %12 2
%18 = OpTypeRuntimeArray %17
%19 = OpTypeStruct %18
%21 = OpTypePointer Workgroup %14
%20 = OpVariable %21 Workgroup
%23 = OpTypePointer Workgroup %6
%7 = OpConstant %8 20
%9 = OpConstant %8 3
%10 = OpConstant %8 2
%11 = OpConstant %6 2
%13 = OpTypeFloat 32
%12 = OpConstant %13 1.0
%14 = OpConstantTrue %4
%15 = OpTypeArray %13 %5
%16 = OpTypeVector %13 3
%17 = OpTypeStruct %16 %13
%18 = OpTypeVector %13 2
%19 = OpTypeRuntimeArray %18
%20 = OpTypeVector %13 4
%21 = OpTypeArray %20 %7
%23 = OpTypePointer Workgroup %15
%22 = OpVariable %23 Workgroup
%25 = OpTypeStruct %16
%26 = OpTypePointer StorageBuffer %25
%24 = OpVariable %26 StorageBuffer
%28 = OpTypePointer StorageBuffer %19
%27 = OpVariable %28 StorageBuffer
%30 = OpTypePointer Function %12
%32 = OpTypePointer Function %4
%35 = OpTypeFunction %2
%36 = OpTypePointer StorageBuffer %16
%37 = OpConstant %6 0
%40 = OpTypePointer Workgroup %12
%41 = OpTypePointer StorageBuffer %12
%42 = OpConstant %6 1
%45 = OpConstant %6 3
%47 = OpTypePointer StorageBuffer %15
%48 = OpTypePointer StorageBuffer %12
%52 = OpConstant %6 256
%34 = OpFunction %2 None %35
%33 = OpLabel
%29 = OpVariable %30 Function %11
%31 = OpVariable %32 Function %13
%38 = OpAccessChain %36 %24 %37
OpBranch %39
%25 = OpTypePointer Workgroup %6
%24 = OpVariable %25 Workgroup
%27 = OpTypeStruct %17
%28 = OpTypePointer StorageBuffer %27
%26 = OpVariable %28 StorageBuffer
%30 = OpTypeStruct %19
%31 = OpTypePointer StorageBuffer %30
%29 = OpVariable %31 StorageBuffer
%33 = OpTypeStruct %21
%34 = OpTypePointer Uniform %33
%32 = OpVariable %34 Uniform
%36 = OpTypePointer Function %13
%38 = OpTypePointer Function %4
%41 = OpTypeFunction %2
%42 = OpTypePointer StorageBuffer %17
%43 = OpConstant %6 0
%45 = OpTypePointer StorageBuffer %19
%46 = OpTypePointer Uniform %21
%48 = OpTypePointer Workgroup %13
%49 = OpTypePointer StorageBuffer %13
%50 = OpConstant %6 1
%53 = OpConstant %6 3
%55 = OpTypePointer StorageBuffer %16
%56 = OpTypePointer StorageBuffer %13
%60 = OpConstant %6 256
%40 = OpFunction %2 None %41
%39 = OpLabel
%43 = OpAccessChain %41 %38 %42
%44 = OpLoad %12 %43
%46 = OpAccessChain %40 %20 %45
OpStore %46 %44
%49 = OpAccessChain %48 %38 %37 %37
%50 = OpLoad %12 %49
%51 = OpAccessChain %40 %20 %10
OpStore %51 %50
OpAtomicStore %22 %9 %52 %10
%35 = OpVariable %36 Function %12
%37 = OpVariable %38 Function %14
%44 = OpAccessChain %42 %26 %43
OpBranch %47
%47 = OpLabel
%51 = OpAccessChain %49 %44 %50
%52 = OpLoad %13 %51
%54 = OpAccessChain %48 %22 %53
OpStore %54 %52
%57 = OpAccessChain %56 %44 %43 %43
%58 = OpLoad %13 %57
%59 = OpAccessChain %48 %22 %11
OpStore %59 %58
OpAtomicStore %24 %10 %60 %11
OpReturn
OpFunctionEnd
Loading

0 comments on commit a00638f

Please sign in to comment.