Skip to content

Commit

Permalink
Make clamp_weights pass mandatory
Browse files Browse the repository at this point in the history
  • Loading branch information
BramOtte committed Dec 11, 2023
1 parent 5f9d510 commit 9d7c5f4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
18 changes: 10 additions & 8 deletions crates/core/src/redpiler/backend/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,10 @@ struct ForwardLink {
}

impl ForwardLink {
pub fn new(id: NodeId, side: bool, mut ss: u8) -> Self {
pub fn new(id: NodeId, side: bool, ss: u8) -> Self {
assert!(id.index() < (1 << 27));
if ss >= 16 {
ss = 15;
}
// the clamp_weights compile pass should ensure ss < 16
assert!(ss < 16);
Self {
data: (id.index() as u32) << 5 | if side { 1 << 4 } else { 0 } | ss as u32,
}
Expand Down Expand Up @@ -183,9 +182,9 @@ impl Node {
stats: &mut FinalGraphStats,
) -> Self {
let node = &graph[node_idx];

const MAX_INPUTS: usize = 255;

let mut default_input_count = 0;
let mut side_input_count = 0;

Expand All @@ -199,11 +198,14 @@ impl Node {
match weight.ty {
LinkType::Default => {
if default_input_count >= MAX_INPUTS {
panic!("Exceeded the maximum number of default inputs {}", MAX_INPUTS);
panic!(
"Exceeded the maximum number of default inputs {}",
MAX_INPUTS
);
}
default_input_count += 1;
default_inputs[ss as usize] += 1;
},
}
LinkType::Side => {
if side_input_count >= MAX_INPUTS {
panic!("Exceeded the maximum number of side inputs {}", MAX_INPUTS);
Expand Down
5 changes: 5 additions & 0 deletions crates/core/src/redpiler/passes/clamp_weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,9 @@ impl<W: World> Pass<W> for ClampWeights {
fn run_pass(&self, graph: &mut CompileGraph, _: &CompilerOptions, _: &CompilerInput<'_, W>) {
graph.retain_edges(|g, edge| g[edge].ss < 15);
}

fn should_run(&self, _: &CompilerOptions) -> bool {
// Mandatory
true
}
}

0 comments on commit 9d7c5f4

Please sign in to comment.