Skip to content

Commit

Permalink
add prune support
Browse files Browse the repository at this point in the history
  • Loading branch information
robtfm committed May 17, 2024
1 parent 4c12bce commit 47eff7c
Showing 1 changed file with 57 additions and 3 deletions.
60 changes: 57 additions & 3 deletions src/prune/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ impl FunctionReq {
committed: *committed,
}
}
Expression::Override(_) => expr.clone(),
Expression::SubgroupBallotResult => expr.clone(),
Expression::SubgroupOperationResult { .. } => expr.clone(),
}
}

Expand Down Expand Up @@ -1372,6 +1375,15 @@ impl<'a> Pruner<'a> {
} => {
self.add_expression(function, func_req, context, *query, &PartReq::All);
}
Expression::Override(_) => {
// we don't prune overrides, so nothing to do
}
Expression::SubgroupBallotResult => {
// nothing, handled by the statement
}
Expression::SubgroupOperationResult { .. } => {
// nothing, handled by the statement
}
}

func_req.exprs_required.insert(h_expr, part.clone());
Expand Down Expand Up @@ -1676,6 +1688,48 @@ impl<'a> Pruner<'a> {
}
RayQuery(required)
}
Statement::SubgroupBallot { result, predicate } => {
let var_ref = Self::resolve_var(function, *result, Vec::default());
let required = self.store_required(context, &var_ref).is_some();
if required {
if let Some(predicate) = predicate {
self.add_expression(function, func_req, context, *predicate, &PartReq::All);
}
}
RayQuery(required)
}
Statement::SubgroupGather {
mode,
argument,
result,
} => {
let var_ref = Self::resolve_var(function, *result, Vec::default());
let required = self.store_required(context, &var_ref).is_some();
if required {
match mode {
naga::GatherMode::BroadcastFirst => (),
naga::GatherMode::Broadcast(h_src)
| naga::GatherMode::Shuffle(h_src)
| naga::GatherMode::ShuffleDown(h_src)
| naga::GatherMode::ShuffleUp(h_src)
| naga::GatherMode::ShuffleXor(h_src) => {
self.add_expression(function, func_req, context, *h_src, &PartReq::All)
}
}
self.add_expression(function, func_req, context, *argument, &PartReq::All);
}
RayQuery(required)
}
Statement::SubgroupCollectiveOperation {
argument, result, ..
} => {
let var_ref = Self::resolve_var(function, *result, Vec::default());
let required = self.store_required(context, &var_ref).is_some();
if required {
self.add_expression(function, func_req, context, *argument, &PartReq::All);
}
RayQuery(required)
}
}
}

Expand Down Expand Up @@ -1802,9 +1856,9 @@ impl<'a> Pruner<'a> {
let mut derived = DerivedModule::default();
derived.set_shader_source(self.module, 0);

// just copy all the constants for now, so we can copy const handles as well
for (h_cexpr, _) in self.module.const_expressions.iter() {
derived.import_const_expression(h_cexpr);
// just copy all the (pipeline + normal) constants for now, so we can copy const handles as well
for (h_cexpr, _) in self.module.global_expressions.iter() {
derived.import_global_expression(h_cexpr);
}

for (h_f, f) in self.module.functions.iter() {
Expand Down

0 comments on commit 47eff7c

Please sign in to comment.