diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index c4734addd8..632e6cc6a7 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -229,6 +229,7 @@ impl<'tcx> ConvSpirvType<'tcx> for CastTarget { field_types: args, field_offsets, field_names: None, + is_block: false, } .def(span, cx) } @@ -340,6 +341,7 @@ fn trans_type_impl<'tcx>( field_types: Vec::new(), field_offsets: Vec::new(), field_names: None, + is_block: false, } .def(span, cx), Abi::Scalar(ref scalar) => trans_scalar(cx, span, ty, scalar, None, is_immediate), @@ -359,6 +361,7 @@ fn trans_type_impl<'tcx>( field_types: vec![one_spirv, two_spirv], field_offsets: vec![one_offset, two_offset], field_names: None, + is_block: false, } .def(span, cx) } @@ -582,6 +585,20 @@ fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Optio None } +/// Handles `#[spirv(block)]`. Note this is only called in the scalar translation code, because this is only +/// used for spooky builtin stuff, and we pinky promise to never have more than one pointer field in one of these. +// TODO: Enforce this is only used in spirv-std. +fn get_is_block_decorated<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> bool { + if let TyKind::Adt(adt, _substs) = ty.ty.kind() { + for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) { + if let SpirvAttribute::Block = attr { + return true; + } + } + } + false +} + fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word { match ty.fields { FieldsShape::Primitive => cx.tcx.sess.fatal(&format!( @@ -618,6 +635,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx> field_types: Vec::new(), field_offsets: Vec::new(), field_names: None, + is_block: false, } .def(span, cx) } else { @@ -711,6 +729,7 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) - } }; } + let is_block = get_is_block_decorated(cx, ty); SpirvType::Adt { name, size, @@ -718,6 +737,7 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) - field_types, field_offsets, field_names: Some(field_names), + is_block, } .def(span, cx) } diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index f20d09617d..f2711ae7e7 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -1225,8 +1225,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { - let val_pointee = match self.lookup_type(val.ty) { - SpirvType::Pointer { pointee, .. } => pointee, + let (storage_class, val_pointee) = match self.lookup_type(val.ty) { + SpirvType::Pointer { + storage_class, + pointee, + } => (storage_class, pointee), other => self.fatal(&format!( "pointercast called on non-pointer source type: {:?}", other @@ -1242,6 +1245,20 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { if val.ty == dest_ty { val } else if let Some(indices) = self.try_pointercast_via_gep(val_pointee, dest_pointee) { + let dest_ty = if self + .really_unsafe_ignore_bitcasts + .borrow() + .contains(&self.current_fn) + { + SpirvType::Pointer { + storage_class, + pointee: dest_pointee, + } + // TODO: Get actual span here + .def(Span::default(), self) + } else { + dest_ty + }; let indices = indices .into_iter() .map(|idx| self.constant_u32(self.span(), idx).def(self)) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index ae6f9773a9..9e68c8352e 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -187,6 +187,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { field_types, field_offsets, field_names: None, + is_block: false, } .def(DUMMY_SP, self); self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect()) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs index 701984edc6..3672443c17 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs @@ -150,6 +150,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { field_types: els.to_vec(), field_offsets, field_names: None, + is_block: false, } .def(DUMMY_SP, self) } diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index e13cbe6cf0..7569569c26 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -34,6 +34,7 @@ pub enum SpirvType { field_types: Vec, field_offsets: Vec, field_names: Option>, + is_block: bool, }, Opaque { name: String, @@ -126,6 +127,7 @@ impl SpirvType { ref field_types, ref field_offsets, ref field_names, + is_block, } => { let mut emit = cx.emit_global(); // Ensure a unique struct is emitted each time, due to possibly having different OpMemberDecorates @@ -146,6 +148,9 @@ impl SpirvType { ); } } + if is_block { + emit.decorate(id, Decoration::Block, None); + } if let Some(field_names) = field_names { for (index, field_name) in field_names.iter().enumerate() { emit.member_name(result, index as u32, field_name); @@ -344,6 +349,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> { ref field_types, ref field_offsets, ref field_names, + is_block, } => { let fields = field_types .iter() @@ -357,6 +363,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> { .field("field_types", &fields) .field("field_offsets", field_offsets) .field("field_names", field_names) + .field("is_block", &is_block) .finish() } SpirvType::Opaque { ref name } => f @@ -485,6 +492,7 @@ impl SpirvTypePrinter<'_, '_> { ref field_types, field_offsets: _, ref field_names, + is_block: _, } => { write!(f, "struct {} {{ ", name)?; for (index, &field) in field_types.iter().enumerate() { diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index 45dd86c529..63af6518a8 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -336,6 +336,7 @@ impl Symbols { SpirvAttribute::ReallyUnsafeIgnoreBitcasts, ), ("sampler", SpirvAttribute::Sampler), + ("block", SpirvAttribute::Block), ] .iter() .cloned(); @@ -437,6 +438,7 @@ impl From for Entry { pub enum SpirvAttribute { Builtin(BuiltIn), StorageClass(StorageClass), + Block, Entry(Entry), DescriptorSet(u32), Binding(u32), diff --git a/crates/spirv-std/src/storage_class.rs b/crates/spirv-std/src/storage_class.rs index 077d310a31..4fc5740978 100644 --- a/crates/spirv-std/src/storage_class.rs +++ b/crates/spirv-std/src/storage_class.rs @@ -48,6 +48,54 @@ macro_rules! storage_class { } }; + // Interior Block + ($(#[$($meta:meta)+])* block $block:ident storage_class $name:ident ; $($tt:tt)*) => { + + #[spirv(block)] + #[allow(unused_attributes)] + pub struct $block { + value: T + } + + $(#[$($meta)+])* + #[allow(unused_attributes)] + pub struct $name<'block, T> { + block: &'block mut $block , + } + + impl $name<'_, T> { + /// Load the value into memory. + #[inline] + #[allow(unused_attributes)] + #[spirv(really_unsafe_ignore_bitcasts)] + pub fn load(&self) -> T { + self.block.value + } + } + + storage_class!($($tt)*); + }; + + // Methods available on writeable storage classes. + ($(#[$($meta:meta)+])* writeable block $block:ident storage_class $name:ident $($tt:tt)+) => { + storage_class!($(#[$($meta)+])* block $block storage_class $name $($tt)+); + + impl $name<'_, T> { + /// Store the value in storage. + #[inline] + #[allow(unused_attributes)] + #[spirv(really_unsafe_ignore_bitcasts)] + pub fn store(&mut self, v: T) { + self.block.value = v; + } + + /// A convenience function to load a value into memory and store it. + pub fn then(&mut self, then: impl FnOnce(T) -> T) { + self.store((then)(self.load())); + } + } + }; + (;) => {}; () => {}; } @@ -112,7 +160,7 @@ storage_class! { /// Intended to contain a small bank of values pushed from the client API. /// Variables declared with this storage class are read-only, and must not /// have initializers. - #[spirv(push_constant)] storage_class PushConstant; + #[spirv(push_constant)] block PushConstantBlock storage_class PushConstant; /// Atomic counter-specific memory. /// @@ -131,7 +179,7 @@ storage_class! { /// /// Shared externally, readable and writable, visible across all functions /// in all invocations in all work groups. - #[spirv(storage_buffer)] writeable storage_class StorageBuffer; + #[spirv(storage_buffer)] writeable block StorageBufferBlock storage_class StorageBuffer; /// Used for storing arbitrary data associated with a ray to pass /// to callables. (Requires `SPV_KHR_ray_tracing` extension)