diff --git a/src/prune/mod.rs b/src/prune/mod.rs index 9b0cc4d..94c5bd8 100644 --- a/src/prune/mod.rs +++ b/src/prune/mod.rs @@ -248,6 +248,9 @@ impl FunctionReq { committed: *committed, } } + Expression::Override(_) => expr.clone(), + Expression::SubgroupBallotResult => expr.clone(), + Expression::SubgroupOperationResult { .. } => expr.clone(), } } @@ -1372,6 +1375,15 @@ impl<'a> Pruner<'a> { } => { self.add_expression(function, func_req, context, *query, &PartReq::All); } + Expression::Override(_) => { + // we don't prune overrides, so nothing to do + } + Expression::SubgroupBallotResult => { + // nothing, handled by the statement + } + Expression::SubgroupOperationResult { .. } => { + // nothing, handled by the statement + } } func_req.exprs_required.insert(h_expr, part.clone()); @@ -1676,6 +1688,48 @@ impl<'a> Pruner<'a> { } RayQuery(required) } + Statement::SubgroupBallot { result, predicate } => { + let var_ref = Self::resolve_var(function, *result, Vec::default()); + let required = self.store_required(context, &var_ref).is_some(); + if required { + if let Some(predicate) = predicate { + self.add_expression(function, func_req, context, *predicate, &PartReq::All); + } + } + RayQuery(required) + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + let var_ref = Self::resolve_var(function, *result, Vec::default()); + let required = self.store_required(context, &var_ref).is_some(); + if required { + match mode { + naga::GatherMode::BroadcastFirst => (), + naga::GatherMode::Broadcast(h_src) + | naga::GatherMode::Shuffle(h_src) + | naga::GatherMode::ShuffleDown(h_src) + | naga::GatherMode::ShuffleUp(h_src) + | naga::GatherMode::ShuffleXor(h_src) => { + self.add_expression(function, func_req, context, *h_src, &PartReq::All) + } + } + self.add_expression(function, func_req, context, *argument, &PartReq::All); + } + RayQuery(required) + } + Statement::SubgroupCollectiveOperation { + argument, result, .. + } => { + let var_ref = Self::resolve_var(function, *result, Vec::default()); + let required = self.store_required(context, &var_ref).is_some(); + if required { + self.add_expression(function, func_req, context, *argument, &PartReq::All); + } + RayQuery(required) + } } } @@ -1802,9 +1856,9 @@ impl<'a> Pruner<'a> { let mut derived = DerivedModule::default(); derived.set_shader_source(self.module, 0); - // just copy all the constants for now, so we can copy const handles as well - for (h_cexpr, _) in self.module.const_expressions.iter() { - derived.import_const_expression(h_cexpr); + // just copy all the (pipeline + normal) constants for now, so we can copy const handles as well + for (h_cexpr, _) in self.module.global_expressions.iter() { + derived.import_global_expression(h_cexpr); } for (h_f, f) in self.module.functions.iter() {