Skip to content

Commit

Permalink
Allow non-structure buffer types
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Jan 24, 2022
1 parent 4c91abe commit 508d29a
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 82 deletions.
4 changes: 3 additions & 1 deletion src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, " {}", name)?;
if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
if let crate::StorageClass::Storage { access: _ } = global.class {
// do nothing
} else if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
self.write_array_size(module, size)?;
}

Expand Down
41 changes: 25 additions & 16 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,18 +368,25 @@ fn should_pack_struct_member(
}

fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = arena[member.ty].inner
{
return true;
match arena[ty].inner {
crate::TypeInner::Struct { ref members, .. } => {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = arena[member.ty].inner
{
return true;
}
}
false
}
crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} => true,
_ => false,
}
false
}

impl crate::StorageClass {
Expand Down Expand Up @@ -741,16 +748,18 @@ impl<W: Write> Writer<W> {
context: &ExpressionContext,
) -> BackendResult {
let global = &context.module.global_variables[handle];
let members = match context.module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => members,
let (offset, array_ty) = match context.module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => match members.last() {
Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
None => return Err(Error::Validation),
},
crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} => (0, global.ty),
_ => return Err(Error::Validation),
};

let (offset, array_ty) = match members.last() {
Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
None => return Err(Error::Validation),
};

let (size, stride) = match context.module.types[array_ty].inner {
crate::TypeInner::Array { base, stride, .. } => (
context.module.types[base]
Expand Down
3 changes: 2 additions & 1 deletion src/back/spv/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab
},
None => false,
},
_ => false,
// if it's not a structure, let's wrap it to be able to put "Block"
_ => true,
}
}
20 changes: 9 additions & 11 deletions src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,15 @@ impl super::Validator {
true,
)
}
crate::StorageClass::Handle => (TypeFlags::empty(), true),
crate::StorageClass::Handle => {
match types[var.ty].inner {
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {}
_ => {
return Err(GlobalVariableError::InvalidType);
}
};
(TypeFlags::empty(), true)
}
crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
(TypeFlags::DATA | TypeFlags::SIZED, false)
}
Expand All @@ -375,16 +383,6 @@ impl super::Validator {
}
};

let is_handle = var.class == crate::StorageClass::Handle;
let good_type = match types[var.ty].inner {
crate::TypeInner::Struct { .. } => !is_handle,
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => is_handle,
_ => false,
};
if is_resource && !good_type {
return Err(GlobalVariableError::InvalidType);
}

if !type_info.flags.contains(required_type_flags) {
return Err(GlobalVariableError::MissingTypeFlags {
seen: type_info.flags,
Expand Down
6 changes: 1 addition & 5 deletions tests/in/globals.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@ struct Foo {
@group(0) @binding(1)
var<storage> alignment: Foo;

struct Dummy {
arr: array<vec2<f32>>;
};

@group(0) @binding(2)
var<storage> dummy: Dummy;
var<storage> dummy: array<vec2<f32>>;

@stage(compute) @workgroup_size(1)
fn main() {
Expand Down
3 changes: 0 additions & 3 deletions tests/out/msl/globals.msl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ struct Foo {
float v1_;
};
typedef metal::float2 type_6[1];
struct Dummy {
type_6 arr;
};

kernel void main_(
threadgroup type_2& wg
Expand Down
81 changes: 41 additions & 40 deletions tests/out/spv/globals.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 53
; Bound: 54
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
Expand All @@ -12,16 +12,16 @@ OpDecorate %14 ArrayStride 4
OpMemberDecorate %16 0 Offset 0
OpMemberDecorate %16 1 Offset 12
OpDecorate %18 ArrayStride 8
OpMemberDecorate %19 0 Offset 0
OpDecorate %24 NonWritable
OpDecorate %24 DescriptorSet 0
OpDecorate %24 Binding 1
OpDecorate %25 Block
OpMemberDecorate %25 0 Offset 0
OpDecorate %27 NonWritable
OpDecorate %27 DescriptorSet 0
OpDecorate %27 Binding 2
OpDecorate %19 Block
OpDecorate %23 NonWritable
OpDecorate %23 DescriptorSet 0
OpDecorate %23 Binding 1
OpDecorate %24 Block
OpMemberDecorate %24 0 Offset 0
OpDecorate %26 NonWritable
OpDecorate %26 DescriptorSet 0
OpDecorate %26 Binding 2
OpDecorate %27 Block
OpMemberDecorate %27 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeBool
%3 = OpConstantTrue %4
Expand All @@ -39,43 +39,44 @@ OpDecorate %19 Block
%16 = OpTypeStruct %15 %12
%17 = OpTypeVector %12 2
%18 = OpTypeRuntimeArray %17
%19 = OpTypeStruct %18
%21 = OpTypePointer Workgroup %14
%20 = OpVariable %21 Workgroup
%23 = OpTypePointer Workgroup %6
%22 = OpVariable %23 Workgroup
%25 = OpTypeStruct %16
%26 = OpTypePointer StorageBuffer %25
%24 = OpVariable %26 StorageBuffer
%28 = OpTypePointer StorageBuffer %19
%27 = OpVariable %28 StorageBuffer
%20 = OpTypePointer Workgroup %14
%19 = OpVariable %20 Workgroup
%22 = OpTypePointer Workgroup %6
%21 = OpVariable %22 Workgroup
%24 = OpTypeStruct %16
%25 = OpTypePointer StorageBuffer %24
%23 = OpVariable %25 StorageBuffer
%27 = OpTypeStruct %18
%28 = OpTypePointer StorageBuffer %27
%26 = OpVariable %28 StorageBuffer
%30 = OpTypePointer Function %12
%32 = OpTypePointer Function %4
%35 = OpTypeFunction %2
%36 = OpTypePointer StorageBuffer %16
%37 = OpConstant %6 0
%40 = OpTypePointer Workgroup %12
%41 = OpTypePointer StorageBuffer %12
%42 = OpConstant %6 1
%45 = OpConstant %6 3
%47 = OpTypePointer StorageBuffer %15
%48 = OpTypePointer StorageBuffer %12
%52 = OpConstant %6 256
%39 = OpTypePointer StorageBuffer %18
%41 = OpTypePointer Workgroup %12
%42 = OpTypePointer StorageBuffer %12
%43 = OpConstant %6 1
%46 = OpConstant %6 3
%48 = OpTypePointer StorageBuffer %15
%49 = OpTypePointer StorageBuffer %12
%53 = OpConstant %6 256
%34 = OpFunction %2 None %35
%33 = OpLabel
%29 = OpVariable %30 Function %11
%31 = OpVariable %32 Function %13
%38 = OpAccessChain %36 %24 %37
OpBranch %39
%39 = OpLabel
%43 = OpAccessChain %41 %38 %42
%44 = OpLoad %12 %43
%46 = OpAccessChain %40 %20 %45
OpStore %46 %44
%49 = OpAccessChain %48 %38 %37 %37
%50 = OpLoad %12 %49
%51 = OpAccessChain %40 %20 %10
OpStore %51 %50
OpAtomicStore %22 %9 %52 %10
%38 = OpAccessChain %36 %23 %37
OpBranch %40
%40 = OpLabel
%44 = OpAccessChain %42 %38 %43
%45 = OpLoad %12 %44
%47 = OpAccessChain %41 %19 %46
OpStore %47 %45
%50 = OpAccessChain %49 %38 %37 %37
%51 = OpLoad %12 %50
%52 = OpAccessChain %41 %19 %10
OpStore %52 %51
OpAtomicStore %21 %9 %53 %10
OpReturn
OpFunctionEnd
6 changes: 1 addition & 5 deletions tests/out/wgsl/globals.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@ struct Foo {
v1_: f32;
};

struct Dummy {
arr: array<vec2<f32>>;
};

let Foo_2: bool = true;

var<workgroup> wg: array<f32,10u>;
var<workgroup> at_1: atomic<u32>;
@group(0) @binding(1)
var<storage> alignment: Foo;
@group(0) @binding(2)
var<storage> dummy: Dummy;
var<storage> dummy: array<vec2<f32>>;

@stage(compute) @workgroup_size(1, 1, 1)
fn main() {
Expand Down

0 comments on commit 508d29a

Please sign in to comment.