Skip to content

Commit

Permalink
msl: fix packed vec access (#1634)
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark authored Dec 28, 2021
1 parent 5a26606 commit 2738ad8
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 96 deletions.
111 changes: 65 additions & 46 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ impl<'a> Display for TypeContext<'a> {
// work around Metal toolchain bug with `uint` typedef
crate::ScalarKind::Uint => write!(out, "{}::uint", NAMESPACE),
_ => {
let kind_str = scalar_kind_string(kind);
let kind_str = kind.to_msl_name();
write!(out, "{}", kind_str)
}
}
}
crate::TypeInner::Atomic { kind, .. } => {
write!(out, "{}::atomic_{}", NAMESPACE, scalar_kind_string(kind))
write!(out, "{}::atomic_{}", NAMESPACE, kind.to_msl_name())
}
crate::TypeInner::Vector { size, kind, .. } => {
write!(
out,
"{}::{}{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size),
)
}
Expand All @@ -69,7 +69,7 @@ impl<'a> Display for TypeContext<'a> {
out,
"{}::{}{}x{}",
NAMESPACE,
scalar_kind_string(crate::ScalarKind::Float),
crate::ScalarKind::Float.to_msl_name(),
back::vector_size_str(columns),
back::vector_size_str(rows),
)
Expand All @@ -96,7 +96,7 @@ impl<'a> Display for TypeContext<'a> {
Some(name) => name,
None => return Ok(()),
};
write!(out, "{} {}&", class_name, scalar_kind_string(kind),)
write!(out, "{} {}&", class_name, kind.to_msl_name(),)
}
crate::TypeInner::ValuePointer {
size: Some(size),
Expand All @@ -113,7 +113,7 @@ impl<'a> Display for TypeContext<'a> {
"{} {}::{}{}&",
class_name,
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size),
)
}
Expand Down Expand Up @@ -178,7 +178,7 @@ impl<'a> Display for TypeContext<'a> {
("texture", "", format.into(), access)
}
};
let base_name = scalar_kind_string(kind);
let base_name = kind.to_msl_name();
let array_str = if arrayed { "_array" } else { "" };
write!(
out,
Expand Down Expand Up @@ -316,12 +316,14 @@ pub struct Writer<W> {
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}

fn scalar_kind_string(kind: crate::ScalarKind) -> &'static str {
match kind {
crate::ScalarKind::Float => "float",
crate::ScalarKind::Sint => "int",
crate::ScalarKind::Uint => "uint",
crate::ScalarKind::Bool => "bool",
impl crate::ScalarKind {
fn to_msl_name(self) -> &'static str {
match self {
Self::Float => "float",
Self::Sint => "int",
Self::Uint => "uint",
Self::Bool => "bool",
}
}
}

Expand Down Expand Up @@ -481,6 +483,29 @@ impl<'a> ExpressionContext<'a> {
) -> Option<index::IndexableLength> {
index::access_needs_check(base, index, self.module, self.function, self.info)
}

// Because packed vectors such as `packed_float3` cannot be directly loaded,
// we convert them to unpacked vectors like `float3` on load.
fn get_packed_vec_kind(
&self,
expr_handle: Handle<crate::Expression>,
) -> Option<crate::ScalarKind> {
match self.function.expressions[expr_handle] {
crate::Expression::AccessIndex { base, index } => {
let ty = match *self.resolve_type(base) {
crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
ref ty => ty,
};
match *ty {
crate::TypeInner::Struct {
ref members, span, ..
} => should_pack_struct_member(members, span, index as usize, self.module),
_ => None,
}
}
_ => None,
}
}
}

struct StatementContext<'a> {
Expand Down Expand Up @@ -652,15 +677,15 @@ impl<W: Write> Writer<W> {
) -> BackendResult {
match context.module.types[ty].inner {
crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => {
write!(self.out, "{}", scalar_kind_string(kind))?;
write!(self.out, "{}", kind.to_msl_name())?;
self.put_call_parameters(components.iter().cloned(), context)?;
}
crate::TypeInner::Vector { size, kind, .. } => {
write!(
self.out,
"{}::{}{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size)
)?;
self.put_call_parameters(components.iter().cloned(), context)?;
Expand All @@ -671,7 +696,7 @@ impl<W: Write> Writer<W> {
self.out,
"{}::{}{}x{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(columns),
back::vector_size_str(rows)
)?;
Expand Down Expand Up @@ -845,7 +870,7 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Scalar { kind, .. } => kind,
_ => return Err(Error::Validation),
};
let scalar = scalar_kind_string(scalar_kind);
let scalar = scalar_kind.to_msl_name();
let size = back::vector_size_str(size);

write!(self.out, "{}::{}{}(", NAMESPACE, scalar, size)?;
Expand Down Expand Up @@ -1246,7 +1271,7 @@ impl<W: Write> Writer<W> {
kind,
convert,
} => {
let scalar = scalar_kind_string(kind);
let scalar = kind.to_msl_name();
let (src_kind, src_width) = match *context.resolve_type(expr) {
crate::TypeInner::Scalar { kind, width }
| crate::TypeInner::Vector { kind, width, .. } => (kind, width),
Expand Down Expand Up @@ -1487,7 +1512,15 @@ impl<W: Write> Writer<W> {
write!(self.out, ".{}", name)?;
}
crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
self.put_access_chain(base, policy, context)?;
let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(base);
//Note: this doesn't work for left-hand side
if let Some(scalar_kind) = wrap_packed_vec_scalar_kind {
write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
self.put_access_chain(base, policy, context)?;
write!(self.out, ")")?;
} else {
self.put_access_chain(base, policy, context)?;
}
write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
}
_ => {
Expand Down Expand Up @@ -1614,23 +1647,7 @@ impl<W: Write> Writer<W> {
policy: index::BoundsCheckPolicy,
context: &ExpressionContext,
) -> BackendResult {
// Because packed vectors such as `packed_float3` cannot be directly multipied by
// matrices, we convert them to unpacked vectors like `float3` on load.
let wrap_packed_vec_scalar_kind = match context.function.expressions[pointer] {
crate::Expression::AccessIndex { base, index } => {
let ty = match *context.resolve_type(base) {
crate::TypeInner::Pointer { base, .. } => &context.module.types[base].inner,
ref ty => ty,
};
match *ty {
crate::TypeInner::Struct {
ref members, span, ..
} => should_pack_struct_member(members, span, index as usize, context.module),
_ => None,
}
}
_ => None,
};
let wrap_packed_vec_scalar_kind = context.get_packed_vec_kind(pointer);
let is_atomic = match *context.resolve_type(pointer) {
crate::TypeInner::Pointer { base, .. } => match context.module.types[base].inner {
crate::TypeInner::Atomic { .. } => true,
Expand All @@ -1640,12 +1657,7 @@ impl<W: Write> Writer<W> {
};

if let Some(scalar_kind) = wrap_packed_vec_scalar_kind {
write!(
self.out,
"{}::{}3(",
NAMESPACE,
scalar_kind_string(scalar_kind)
)?;
write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ")")?;
} else if is_atomic {
Expand Down Expand Up @@ -1761,15 +1773,22 @@ impl<W: Write> Writer<W> {
};
write!(self.out, "{}", ty_name)?;
}
TypeResolution::Value(crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
}) => {
// work around Metal toolchain bug with `uint` typedef
write!(self.out, "{}::uint", NAMESPACE)?;
}
TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => {
write!(self.out, "{}", scalar_kind_string(kind))?;
write!(self.out, "{}", kind.to_msl_name())?;
}
TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => {
write!(
self.out,
"{}::{}{}",
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
back::vector_size_str(size)
)?;
}
Expand All @@ -1778,7 +1797,7 @@ impl<W: Write> Writer<W> {
self.out,
"{}::{}{}x{}",
NAMESPACE,
scalar_kind_string(crate::ScalarKind::Float),
crate::ScalarKind::Float.to_msl_name(),
back::vector_size_str(columns),
back::vector_size_str(rows),
)?;
Expand Down Expand Up @@ -2360,7 +2379,7 @@ impl<W: Write> Writer<W> {
"{}{}::packed_{}3 {};",
back::INDENT,
NAMESPACE,
scalar_kind_string(kind),
kind.to_msl_name(),
member_name
)?;
}
Expand Down
1 change: 1 addition & 0 deletions tests/in/globals.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var<storage> alignment: Foo;
[[stage(compute), workgroup_size(1)]]
fn main() {
wg[3] = alignment.v1;
wg[2] = alignment.v3.x;
atomicStore(&at, 2u);

// Valid, Foo and at is in function scope
Expand Down
2 changes: 2 additions & 0 deletions tests/out/glsl/globals.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ void main() {
bool at = true;
float _e7 = _group_0_binding_1_cs.v1_;
wg[3] = _e7;
float _e12 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e12;
at_1 = 2u;
return;
}
Expand Down
2 changes: 2 additions & 0 deletions tests/out/hlsl/globals.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ void main()

float _expr7 = asfloat(alignment.Load(12));
wg[3] = _expr7;
float _expr12 = asfloat(alignment.Load(0+0));
wg[2] = _expr12;
at_1 = 2u;
return;
}
2 changes: 1 addition & 1 deletion tests/out/msl/boids.msl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ kernel void main_(
metal::float2 pos;
metal::float2 vel;
metal::uint i = 0u;
uint index = global_invocation_id.x;
metal::uint index = global_invocation_id.x;
if (index >= NUM_PARTICLES) {
return;
}
Expand Down
2 changes: 2 additions & 0 deletions tests/out/msl/globals.msl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ kernel void main_(
bool at = true;
float _e7 = alignment.v1_;
wg.inner[3] = _e7;
float _e12 = metal::float3(alignment.v3_).x;
wg.inner[2] = _e12;
metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed);
return;
}
2 changes: 1 addition & 1 deletion tests/out/msl/shadow.msl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fragment fs_mainOutput fs_main(
}
loop_init = false;
metal::uint _e12 = i;
uint _e15 = u_globals.num_lights.x;
metal::uint _e15 = u_globals.num_lights.x;
if (_e12 >= metal::min(_e15, c_max_lights)) {
break;
}
Expand Down
Loading

0 comments on commit 2738ad8

Please sign in to comment.