Skip to content

Commit

Permalink
Binding arrays play nice with bounds checks (#1855)
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald authored Apr 25, 2022
1 parent ad28396 commit 1aa9154
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 257 deletions.
23 changes: 16 additions & 7 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1270,13 +1270,14 @@ impl<W: Write> Writer<W> {
let expression = &context.function.expressions[expr_handle];
log::trace!("expression {:?} = {:?}", expr_handle, expression);
match *expression {
crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => {
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.
// Since `put_bounds_checks` and `put_access_chain` handle an entire
// access chain at a time, recursing back through `put_expression` only
// for index expressions and the base object, we will never see intermediate
// `Access` or `AccessIndex` expressions here.
let policy = context.choose_bounds_check_policy(expr_handle);
let policy = context.choose_bounds_check_policy(base);
if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
&& self.put_bounds_checks(
expr_handle,
Expand Down Expand Up @@ -3339,11 +3340,19 @@ impl<W: Write> Writer<W> {
}
if let Some(ref br) = var.binding {
let good = match options.per_stage_map[ep.stage].resources.get(br) {
Some(target) => match module.types[var.ty].inner {
crate::TypeInner::Image { .. } => target.texture.is_some(),
crate::TypeInner::Sampler { .. } => target.sampler.is_some(),
_ => target.buffer.is_some(),
},
Some(target) => {
let binding_ty = match module.types[var.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
&module.types[base].inner
}
ref ty => ty,
};
match *binding_ty {
crate::TypeInner::Image { .. } => target.texture.is_some(),
crate::TypeInner::Sampler { .. } => target.sampler.is_some(),
_ => target.buffer.is_some(),
}
}
None => false,
};
if !good {
Expand Down
21 changes: 16 additions & 5 deletions src/proc/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ pub struct BoundsCheckPolicies {
/// [`ImageStore`]: crate::Statement::ImageStore
#[cfg_attr(feature = "deserialize", serde(default))]
pub image: BoundsCheckPolicy,

/// How should the generated code handle binding array indexes that are out of bounds.
#[cfg_attr(feature = "deserialize", serde(default))]
pub binding_array: BoundsCheckPolicy,
}

/// The default `BoundsCheckPolicy` is `Unchecked`.
Expand All @@ -127,20 +131,27 @@ impl Default for BoundsCheckPolicy {
}

impl BoundsCheckPolicies {
/// Determine which policy applies to `access`.
/// Determine which policy applies to `base`.
///
/// `access` is a subtree of `Access` and `AccessIndex` expressions,
/// operating either on a pointer to a value, or on a value directly.
/// `base` is the "base" expression (the expression being indexed) of a `Access`
/// and `AccessIndex` expression. This is either a pointer, a value, being directly
/// indexed, or a binding array.
///
/// See the documentation for [`BoundsCheckPolicy`] for details about
/// when each policy applies.
pub fn choose_policy(
&self,
access: Handle<crate::Expression>,
base: Handle<crate::Expression>,
types: &UniqueArena<crate::Type>,
info: &valid::FunctionInfo,
) -> BoundsCheckPolicy {
match info[access].ty.inner_with(types).pointer_space() {
let ty = info[base].ty.inner_with(types);

if let crate::TypeInner::BindingArray { .. } = *ty {
return self.binding_array;
}

match ty.pointer_space() {
Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => {
self.buffer
}
Expand Down
5 changes: 5 additions & 0 deletions tests/in/binding-arrays.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,10 @@
binding_map: {
(group: 0, binding: 0): (binding_array_size: Some(10)),
},
),
bounds_check_policies: (
index: ReadZeroSkipWrite,
buffer: ReadZeroSkipWrite,
image: ReadZeroSkipWrite,
)
)
24 changes: 18 additions & 6 deletions tests/out/msl/binding-arrays.msl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

using metal::uint;

struct DefaultConstructible {
template<typename T>
operator T() && {
return T {};
}
};
struct UniformIndex {
uint index;
};
Expand Down Expand Up @@ -65,13 +71,13 @@ fragment main_Output main_(
metal::float4 _e75 = texture_array_depth[non_uniform_index].gather_compare(samp_comp[non_uniform_index], uv, 0.0);
v4_ = _e71 + _e75;
metal::float4 _e77 = v4_;
metal::float4 _e81 = texture_array_unbounded[0].read(metal::uint2(pix), 0);
metal::float4 _e81 = (uint(0) < texture_array_unbounded[0].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[0].get_width(0), texture_array_unbounded[0].get_height(0))) ? texture_array_unbounded[0].read(metal::uint2(pix), 0): DefaultConstructible());
v4_ = _e77 + _e81;
metal::float4 _e83 = v4_;
metal::float4 _e86 = texture_array_unbounded[uniform_index].read(metal::uint2(pix), 0);
metal::float4 _e86 = (uint(0) < texture_array_unbounded[uniform_index].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[uniform_index].get_width(0), texture_array_unbounded[uniform_index].get_height(0))) ? texture_array_unbounded[uniform_index].read(metal::uint2(pix), 0): DefaultConstructible());
v4_ = _e83 + _e86;
metal::float4 _e88 = v4_;
metal::float4 _e91 = texture_array_unbounded[non_uniform_index].read(metal::uint2(pix), 0);
metal::float4 _e91 = (uint(0) < texture_array_unbounded[non_uniform_index].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[non_uniform_index].get_width(0), texture_array_unbounded[non_uniform_index].get_height(0))) ? texture_array_unbounded[non_uniform_index].read(metal::uint2(pix), 0): DefaultConstructible());
v4_ = _e88 + _e91;
int _e93 = i1_;
i1_ = _e93 + int(texture_array_2darray[0].get_array_size());
Expand Down Expand Up @@ -146,11 +152,17 @@ fragment main_Output main_(
metal::float4 _e244 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv, metal::level(0.0));
v4_ = _e240 + _e244;
metal::float4 _e248 = v4_;
texture_array_storage[0].write(_e248, metal::uint2(pix));
if (metal::all(metal::uint2(pix) < metal::uint2(texture_array_storage[0].get_width(), texture_array_storage[0].get_height()))) {
texture_array_storage[0].write(_e248, metal::uint2(pix));
}
metal::float4 _e250 = v4_;
texture_array_storage[uniform_index].write(_e250, metal::uint2(pix));
if (metal::all(metal::uint2(pix) < metal::uint2(texture_array_storage[uniform_index].get_width(), texture_array_storage[uniform_index].get_height()))) {
texture_array_storage[uniform_index].write(_e250, metal::uint2(pix));
}
metal::float4 _e252 = v4_;
texture_array_storage[non_uniform_index].write(_e252, metal::uint2(pix));
if (metal::all(metal::uint2(pix) < metal::uint2(texture_array_storage[non_uniform_index].get_width(), texture_array_storage[non_uniform_index].get_height()))) {
texture_array_storage[non_uniform_index].write(_e252, metal::uint2(pix));
}
metal::int2 _e253 = i2_;
int _e254 = i1_;
metal::float2 v2_ = static_cast<metal::float2>(_e253 + metal::int2(_e254));
Expand Down
Loading

0 comments on commit 1aa9154

Please sign in to comment.