Skip to content

Commit

Permalink
[msl-out] wrap arrays in structs so that they can be returned by func…
Browse files Browse the repository at this point in the history
…tions (#764)

* [msl-out] wrap arrays in structs so that they can be returned by functions

* Fix clippy problems

* use a raw array for output fields

* Fix clippy problems

* Apply suggestions

* Remove put_initialization_component

* Check if the array is a constant size

* Don't use the pointer class
  • Loading branch information
expenses authored Apr 28, 2021
1 parent d9171db commit d21dded
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 53 deletions.
110 changes: 63 additions & 47 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::{
const NAMESPACE: &str = "metal";
const INDENT: &str = " ";
const BAKE_PREFIX: &str = "_e";
const WRAPPED_ARRAY_FIELD: &str = "inner";

#[derive(Clone)]
struct Level(usize);
Expand Down Expand Up @@ -81,12 +82,9 @@ impl<'a> Display for TypeContext<'a> {
}
crate::TypeInner::Pointer { base, class } => {
let sub = Self {
arena: self.arena,
names: self.names,
handle: base,
usage: self.usage,
access: self.access,
first_time: false,
..*self
};
let class_name = match class.get_name(self.usage) {
Some(name) => name,
Expand Down Expand Up @@ -125,7 +123,17 @@ impl<'a> Display for TypeContext<'a> {
vector_size_str(size),
)
}
crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => unreachable!(),
crate::TypeInner::Array { base, .. } => {
let sub = Self {
handle: base,
first_time: false,
..*self
};
// Array lengths go at the end of the type definition,
// so just print the element type here.
write!(out, "{}", sub)
}
crate::TypeInner::Struct { .. } => unreachable!(),
crate::TypeInner::Image {
dim,
arrayed,
Expand Down Expand Up @@ -533,39 +541,6 @@ impl<W: Write> Writer<W> {
Ok(())
}

fn put_initialization_component(
&mut self,
component: Handle<crate::Expression>,
context: &ExpressionContext,
) -> Result<(), Error> {
// we can't initialize the array members just like other members,
// we have to unwrap them one level deeper...
let component_res = &context.info[component].ty;
if let crate::TypeInner::Array {
size: crate::ArraySize::Constant(const_handle),
..
} = *component_res.inner_with(&context.module.types)
{
//HACK: we are forcefully duplicating the expression here,
// it would be nice to find a more C++ idiomatic solution for initializing array members
let size = context.module.constants[const_handle]
.to_array_length()
.unwrap();
write!(self.out, "{{")?;
for j in 0..size {
if j != 0 {
write!(self.out, ",")?;
}
self.put_expression(component, context, false)?;
write!(self.out, "[{}]", j)?;
}
write!(self.out, "}}")?;
} else {
self.put_expression(component, context, true)?;
}
Ok(())
}

fn put_expression(
&mut self,
expr_handle: Handle<crate::Expression>,
Expand All @@ -587,7 +562,25 @@ impl<W: Write> Writer<W> {
log::trace!("expression {:?} = {:?}", expr_handle, expression);
match *expression {
crate::Expression::Access { base, index } => {
let accessing_wrapped_array =
match *context.info[base].ty.inner_with(&context.module.types) {
crate::TypeInner::Array { .. } => true,
crate::TypeInner::Pointer {
base: pointer_base, ..
} => match context.module.types[pointer_base].inner {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(_),
..
} => true,
_ => false,
},
_ => false,
};

self.put_expression(base, context, false)?;
if accessing_wrapped_array {
write!(self.out, ".{}", WRAPPED_ARRAY_FIELD)?;
}
write!(self.out, "[")?;
self.put_expression(index, context, true)?;
write!(self.out, "]")?;
Expand Down Expand Up @@ -690,7 +683,7 @@ impl<W: Write> Writer<W> {
if i != 0 {
write!(self.out, ", ")?;
}
self.put_initialization_component(component, context)?;
self.put_expression(component, context, true)?;
}
write!(self.out, "}}")?;
}
Expand Down Expand Up @@ -1121,7 +1114,9 @@ impl<W: Write> Writer<W> {
let comma = if is_first { "" } else { "," };
is_first = false;
let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
// logic similar to `put_initialization_component`
// HACK: we are forcefully deduplicating the expression here
// to convert from a wrapped struct to a raw array, e.g.
// `float gl_ClipDistance1 [[clip_distance]] [1];`.
if let crate::TypeInner::Array {
size: crate::ArraySize::Constant(const_handle),
..
Expand All @@ -1135,7 +1130,11 @@ impl<W: Write> Writer<W> {
if j != 0 {
write!(self.out, ",")?;
}
write!(self.out, "{}.{}[{}]", tmp, name, j)?;
write!(
self.out,
"{}.{}.{}[{}]",
tmp, name, WRAPPED_ARRAY_FIELD, j
)?;
}
write!(self.out, "}}")?;
} else {
Expand Down Expand Up @@ -1345,9 +1344,9 @@ impl<W: Write> Writer<W> {
.unwrap();
write!(self.out, "{}for(int _i=0; _i<{}; ++_i) ", level, size)?;
self.put_expression(pointer, &context.expression, true)?;
write!(self.out, "[_i] = ")?;
write!(self.out, ".{}[_i] = ", WRAPPED_ARRAY_FIELD)?;
self.put_expression(value, &context.expression, true)?;
writeln!(self.out, "[_i];")?;
writeln!(self.out, ".{}[_i];", WRAPPED_ARRAY_FIELD)?;
}
None => {
write!(self.out, "{}", level)?;
Expand Down Expand Up @@ -1468,7 +1467,7 @@ impl<W: Write> Writer<W> {
access: crate::StorageAccess::empty(),
first_time: false,
};
write!(self.out, "typedef {} {}", base_name, name)?;

match size {
crate::ArraySize::Constant(const_handle) => {
let coco = ConstantContext {
Expand All @@ -1477,10 +1476,17 @@ impl<W: Write> Writer<W> {
names: &self.names,
first_time: false,
};
writeln!(self.out, "[{}];", coco)?;

writeln!(self.out, "struct {} {{", name)?;
writeln!(
self.out,
"{}{} {}[{}];",
INDENT, base_name, WRAPPED_ARRAY_FIELD, coco
)?;
writeln!(self.out, "}};")?;
}
crate::ArraySize::Dynamic => {
writeln!(self.out, "[1];")?;
writeln!(self.out, "typedef {} {}[1];", base_name, name)?;
}
}
}
Expand Down Expand Up @@ -1942,17 +1948,27 @@ impl<W: Write> Writer<W> {
names: &self.names,
usage: GlobalUse::empty(),
access: crate::StorageAccess::empty(),
first_time: false,
first_time: true,
};
let binding = binding.ok_or(Error::Validation)?;
if !pipeline_options.allow_point_size
&& *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize)
{
continue;
}
let array_len = match module.types[ty].inner {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(handle),
..
} => module.constants[handle].to_array_length(),
_ => None,
};
let resolved = options.resolve_local_binding(binding, out_mode)?;
write!(self.out, "{}{} {}", INDENT, ty_name, name)?;
resolved.try_fmt_decorated(&mut self.out, "")?;
if let Some(array_len) = array_len {
write!(self.out, " [{}]", array_len)?;
}
writeln!(self.out, ";")?;
}
writeln!(self.out, "}};")?;
Expand Down
6 changes: 4 additions & 2 deletions tests/out/access.msl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <metal_stdlib>
#include <simd/simd.h>

typedef int type3[5];
struct type3 {
int inner[5];
};

struct fooInput {
};
Expand All @@ -11,5 +13,5 @@ struct fooOutput {
vertex fooOutput foo(
metal::uint vi [[vertex_id]]
) {
return fooOutput { static_cast<float4>(int4(type3 {1, 2, 3, 4, 5}[vi])) };
return fooOutput { static_cast<float4>(int4(type3 {1, 2, 3, 4, 5}.inner[vi])) };
}
10 changes: 6 additions & 4 deletions tests/out/quad-vert.msl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <metal_stdlib>
#include <simd/simd.h>

typedef float type6[1u];
struct type6 {
float inner[1u];
};
struct gl_PerVertex {
metal::float4 gl_Position;
float gl_PointSize;
Expand Down Expand Up @@ -35,7 +37,7 @@ struct main2Output {
metal::float2 member [[user(loc0), center_perspective]];
metal::float4 gl_Position1 [[position]];
float gl_PointSize1 [[point_size]];
type6 gl_ClipDistance1 [[clip_distance]];
float gl_ClipDistance1 [[clip_distance]] [1];
};
vertex main2Output main2(
main2Input varyings [[stage_in]]
Expand All @@ -49,6 +51,6 @@ vertex main2Output main2(
a_uv = a_uv1;
a_pos = a_pos1;
main1(v_uv, a_uv, _, a_pos);
const auto _tmp = type10 {v_uv, _.gl_Position, _.gl_PointSize, {_.gl_ClipDistance[0]}};
return main2Output { _tmp.member, _tmp.gl_Position1, _tmp.gl_PointSize1, {_tmp.gl_ClipDistance1[0]} };
const auto _tmp = type10 {v_uv, _.gl_Position, _.gl_PointSize, _.gl_ClipDistance};
return main2Output { _tmp.member, _tmp.gl_Position1, _tmp.gl_PointSize1, {_tmp.gl_ClipDistance1.inner[0]} };
}

0 comments on commit d21dded

Please sign in to comment.