diff --git a/nnef/src/registry.rs b/nnef/src/registry.rs index d4dd4c8653..3b7f819f6f 100644 --- a/nnef/src/registry.rs +++ b/nnef/src/registry.rs @@ -220,9 +220,13 @@ impl Registry { // mitigation of nnef "scalar" type mismatch with tract-core more // strict types - if (!a_dt.is_quantized() || !b_dt.is_quantized()) && a_dt != b_dt { - let operating_dt = if a_dt == TDim::datum_type() || b_dt == TDim::datum_type() { - TDim::datum_type() + if !a_dt.is_quantized() || !b_dt.is_quantized() { + let operating_dt = if a_dt == b_dt + && bin.1.operating_datum_type(a_dt, b_dt).map(|it| it == a_dt).unwrap_or(false) + { + a_dt + } else if a_dt == TDim::datum_type() || b_dt == TDim::datum_type() { + bin.1.operating_datum_type(a_dt, b_dt)? } else if builder.model.node(a.node).op_is::() { b_dt } else if builder.model.node(b.node).op_is::() { @@ -230,11 +234,10 @@ impl Registry { } else { bin.1.operating_datum_type(a_dt, b_dt)? }; - // avoid cast unified dtype to happen when all inputs quantized - // that can be unaligned at process time a = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[a])?[0]; b = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[b])?[0]; } + let inputs = multi_rank_broadcast(builder, &[a, b])?; let c_dt: Option = dt.first().cloned().and_then(|dt| dt);