Skip to content

Commit

Permalink
fix again nnef binary upcast
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jun 25, 2024
1 parent ab89075 commit 3868b0e
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions nnef/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,24 @@ impl Registry {

// mitigation of nnef "scalar" type mismatch with tract-core more
// strict types
if (!a_dt.is_quantized() || !b_dt.is_quantized()) && a_dt != b_dt {
let operating_dt = if a_dt == TDim::datum_type() || b_dt == TDim::datum_type() {
TDim::datum_type()
if !a_dt.is_quantized() || !b_dt.is_quantized() {
let operating_dt = if a_dt == b_dt
&& bin.1.operating_datum_type(a_dt, b_dt).map(|it| it == a_dt).unwrap_or(false)
{
a_dt
} else if a_dt == TDim::datum_type() || b_dt == TDim::datum_type() {
bin.1.operating_datum_type(a_dt, b_dt)?
} else if builder.model.node(a.node).op_is::<tract_core::ops::konst::Const>() {
b_dt
} else if builder.model.node(b.node).op_is::<tract_core::ops::konst::Const>() {
a_dt
} else {
bin.1.operating_datum_type(a_dt, b_dt)?
};
// avoid cast unified dtype to happen when all inputs quantized
// that can be unaligned at process time
a = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[a])?[0];
b = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[b])?[0];
}

let inputs = multi_rank_broadcast(builder, &[a, b])?;

let c_dt: Option<DatumType> = dt.first().cloned().and_then(|dt| dt);
Expand Down

0 comments on commit 3868b0e

Please sign in to comment.