Commit

Merge 38b44b9 into 8373a7d

kali committed Jun 24, 2024
2 parents 8373a7d + 38b44b9 commit 6ac4547
Showing 17 changed files with 236 additions and 163 deletions.
9 changes: 7 additions & 2 deletions core/src/floats.rs
@@ -4,7 +4,7 @@ use crate::internal::translator::Translate;
use crate::internal::*;
use crate::ops::array::{Pad, PadMode};
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::Cast;
use crate::ops::cast::{cast, Cast};
use crate::ops::einsum::EinSum;
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::konst::Const;
@@ -131,7 +131,12 @@ impl<T1: Datum + Float, T2: Datum + Float>
let new_op = if let Some(source) = node.op_as::<TypedSource>() {
Box::new(TypedSource::new(fact_float_precision_conversion::<T1, T2>(&source.fact)))
} else if let Some(konst) = node.op_as::<Const>() {
Box::new(Const(tensor_float_precision_conversion::<T1, T2>(&konst.0)))
if konst.0.datum_type() == T1::datum_type() {
let wire = target.add_const(format!("{}.{:?}", node.name, T1::datum_type()), konst.0.clone())?;
return target.wire_node(&node.name, cast(T2::datum_type()), &[wire]);
} else {
node.op.clone()
}
} else if let Some(cast) = node.op_as::<Cast>() {
if cast.to == T1::datum_type() {
Box::new(Cast { to: T2::datum_type() })
6 changes: 3 additions & 3 deletions core/src/model/graph.rs
@@ -580,16 +580,16 @@ where
let output_2 = successors.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
writeln!(
fmt,
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {} => {}",
i,
input_1,
input_2,
output_1,
output_2,
self.nodes[i].op().name(),
self.nodes[i].name,
self.node_input_facts(i).unwrap(),
self.node_output_facts(i).unwrap(),
self.node_input_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
self.node_output_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
)?;
if self.nodes[i].inputs.len() > 2 {
writeln!(
14 changes: 11 additions & 3 deletions core/src/model/patch.rs
@@ -6,6 +6,7 @@ use tract_data::itertools::{izip, Itertools};

use crate::internal::*;
use crate::model::*;
use crate::ops::dummy::Dummy;
use crate::ops::konst::Const;

/// A change to apply to a model.
@@ -265,6 +266,7 @@
} = self;
let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
let mut model_input_outlets = target.input_outlets()?.to_vec();
let mut new_nodes = HashSet::new();
for node in patch.nodes {
if <Graph<F, O>>::is_source(&node.op)
&& mapping.contains_key(&OutletId::new(node.id, 0))
@@ -294,6 +296,7 @@
}
let facts = outputs.into_iter().map(|of| of.fact).collect();
let added_node_id = target.add_node(name, op, facts)?;
new_nodes.insert(added_node_id);
for ix in 0..n_outputs {
mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
}
@@ -324,9 +327,6 @@
target.set_outlet_label(replace_by, label)?;
}
}
if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
bail!("Duplicate usage of node as output");
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
for (&node, inputs) in all_inputs.iter().sorted() {
@@ -359,6 +359,14 @@
target.check_edges()?;
}
}
for n in new_nodes.iter() {
let node = &target.nodes[*n];
if target.nodes.iter().filter(|n| n.name == node.name && !n.op_is::<Dummy>()).count()
> 1
{
bail!("Patch created duplicate name with {node}");
}
}
Ok(())
}
}
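A minimal sketch of the behavior this new check enforces (hypothetical example, not part of the commit; `model` is assumed to already contain a non-Dummy node named "dup", and `outlet` to be one of its outlets):

```rust
use tract_core::internal::*;
use tract_core::ops::cast::cast;

// Hypothetical illustration: a patch that re-creates an existing node name is now rejected
// when it is applied, instead of silently yielding a model with duplicate names.
fn demo(model: &mut TypedModel, outlet: OutletId) -> TractResult<()> {
    let mut patch = TypedModelPatch::default();
    let tap = patch.tap_model(model, outlet)?;
    let wire = patch.wire_node("dup", cast(f32::datum_type()), &[tap])?[0];
    patch.shunt_outside(model, outlet, wire)?;
    // Fails with "Patch created duplicate name with …" because "dup" already exists in `model`.
    assert!(patch.apply(model).is_err());
    Ok(())
}
```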
1 change: 1 addition & 0 deletions core/src/model/rewriter.rs
@@ -49,6 +49,7 @@ impl<Ctx> Rewriter<Ctx> {
}
}
if done_anything {
model.prop_consts()?;
model.compact()?;
} else {
return Ok(());
54 changes: 51 additions & 3 deletions core/src/model/typed.rs
@@ -5,6 +5,7 @@ use crate::ops::konst::Const;
use crate::optim::OptimizerSession;
use crate::plan::{FrozenSimpleState, SimplePlan, SimpleState};
use crate::transform::ModelTransform;
use tract_data::UndeterminedSymbol;
use tract_num_traits::Zero;

/// A model with completely determined types and shapes.
@@ -46,6 +47,9 @@ impl SpecialOps<TypedFact, Box<dyn TypedOp>> for TypedModel {
) -> TractResult<TVec<OutletId>> {
let op = op.into();
let name = name.into();
if self.nodes.iter().any(|n| n.name == name) {
bail!("Duplicate node name: {name}");
}
{
let input_facts = inputs
.iter()
@@ -55,7 +59,13 @@ impl SpecialOps<TypedFact, Box<dyn TypedOp>> for TypedModel {
if op.is_stateless() && input_facts.len() > 0 {
if let Some(tensors) = input_facts
.iter()
.map(|f| f.konst.clone().map(|t| t.into_tvalue()))
.map(|f| {
f.konst
.as_ref()
.filter(|k| k.volume() > 16)
.cloned()
.map(|t| t.into_tvalue())
})
.collect::<Option<TVec<_>>>()
{
if let Ok(outputs) = op.eval_with_session(&SessionState::default(), tensors) {
@@ -77,8 +87,10 @@ impl SpecialOps<TypedFact, Box<dyn TypedOp>> for TypedModel {
.output_facts(&input_facts)
.with_context(|| format!("in output_facts invocation for {name}: {}", op.name()))?;
for fact in &mut output_facts {
if fact.konst.is_none() && fact.shape.is_concrete() && fact.shape.volume().is_zero() {
let tensor = Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
if fact.konst.is_none() && fact.shape.is_concrete() && fact.shape.volume().is_zero()
{
let tensor =
Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
fact.konst = Some(tensor.into_arc_tensor());
}
}
@@ -193,6 +205,10 @@ impl TypedModel {
values.translate_model(self)
}

pub fn prop_consts(&mut self) -> TractResult<()> {
crate::optim::Optimizer::prop_consts().optimize(self)
}

/// Translate the graph to locally optimized operators (LIR or MIR ops).
pub fn optimize(&mut self) -> TractResult<()> {
crate::optim::Optimizer::codegen().optimize(self)
@@ -206,6 +222,37 @@ impl TypedModel {
pub fn axes_mapping(&self) -> TractResult<AxesMapping> {
crate::axes::for_model(self)
}

pub fn compute_const_facts(&mut self) -> TractResult<()> {
for n in self.eval_order()? {
let node = self.node(n);
let (inputs, outputs) = self.node_facts(n)?;
if node.op.is_stateless()
&& inputs.iter().all(|i| i.konst.is_some())
&& outputs.iter().any(|o| o.konst.is_none())
{
let inputs_ref =
inputs.iter().map(|f| f.konst.clone().unwrap().into_tvalue()).collect();
match node.op.eval_with_session(&SessionState::default(), inputs_ref) {
Ok(res) => {
drop(inputs);
drop(outputs);
for (ix, output) in res.into_iter().enumerate() {
self.nodes[n].outputs[ix].fact.konst = Some(output.into_arc_tensor());
}
}
Err(e) => {
if !e.root_cause().is::<UndeterminedSymbol>() {
Err(e).with_context(|| {
format!("Eager eval {} during const fact computation", self.node(n))
})?;
}
}
}
}
}
Ok(())
}
}

use crate::model::translator::Translate;
@@ -217,6 +264,7 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Sym
target: &mut TypedModel,
mapping: &HashMap<OutletId, OutletId>,
) -> TractResult<TVec<OutletId>> {
target.check_consistency()?;
let outlets = node.op.concretize_dims(source, node, target, mapping, self)?;
for &outlet in &outlets {
let fact = &mut target.nodes[outlet.node].outputs[outlet.slot].fact;
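A short sketch of how the two new TypedModel entry points are intended to be used together (hypothetical helper; it mirrors the calls added to core/src/ops/scan/mir.rs further down):

```rust
use tract_core::internal::*;

// Hypothetical helper relying on the methods introduced in this commit: after editing a
// sub-model in place (e.g. a scan body), re-derive constant output facts by eager, stateless
// evaluation, then run the PropConst-only optimizer pass to materialize them as Const nodes.
fn refresh_constants(body: &mut TypedModel) -> TractResult<()> {
    body.compute_const_facts()?; // fills facts' `konst` where all inputs are already constant
    body.prop_consts()?;         // Optimizer::prop_consts(), i.e. a single PropConst pass
    Ok(())
}
```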
2 changes: 1 addition & 1 deletion core/src/ops/array/pad.rs
@@ -111,7 +111,7 @@ impl TypedOp for Pad {
as_op!();

fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut fact = inputs[0].clone();
let mut fact = inputs[0].without_value();
if self.pads.len() != fact.rank() {
bail!("Inconsistent pad: input of rank {}, pads are: {:?}", fact.rank(), self.pads);
}
30 changes: 16 additions & 14 deletions core/src/ops/cnn/conv/conv.rs
@@ -191,8 +191,9 @@ impl Conv {
}

let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
let bias_name = &model.node(bias.node).name;
let bias =
model.wire_node(format!("{name}.cast_bias"), cast(mmm.internal_type()), &[bias])?[0];
model.wire_node(format!("{bias_name}.cast"), cast(mmm.internal_type()), &[bias])?[0];
let wire = self.wire_mm_weights_bias(
model,
name,
@@ -730,10 +731,11 @@ impl Conv {

let operand = patch.tap_model(model, other_input)?;

let renamed = format!("{name}.{succ_name}");
let renamed_bias = format!("{name}.{succ_name}.bias");
let renamed_kernel = format!("{name}.{succ_name}.kernel");
bias = wire_reshape_bias_for_bin(
&mut patch,
format!("{renamed}.reshape_bias"),
format!("{renamed_bias}.reshape"),
bias,
1,
0,
@@ -742,7 +744,7 @@

let operand = wire_reshape_bias_for_bin(
&mut patch,
format!("{renamed}.reshape_operand"),
format!("{renamed_bias}.reshape_operand"),
operand,
1,
0,
@@ -755,26 +757,26 @@
operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] =
self.output_channels().to_dim();
let operand_for_kernel = patch.wire_node(
format!("{renamed}.reshape_operand_for_kernel"),
format!("{renamed_kernel}.reshape_operand"),
AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel),
&[operand],
)?[0];

if bin.0.is::<Sub>() && succ_outlet.slot == 0 {
bias = patch.wire_node(&renamed, sub(), &[bias, operand])?[0];
bias = patch.wire_node(&renamed_bias, sub(), &[bias, operand])?[0];
} else if bin.0.is::<Sub>() {
bias = patch.wire_node(&renamed, sub(), &[operand, bias])?[0];
bias = patch.wire_node(&renamed_bias, sub(), &[operand, bias])?[0];
} else if bin.0.is::<Div>() && succ_outlet.slot == 0 {
bias = patch.wire_node(&renamed, div(), &[bias, operand])?[0];
kernel = patch.wire_node(&renamed, div(), &[kernel, operand_for_kernel])?[0];
bias = patch.wire_node(&renamed_bias, div(), &[bias, operand])?[0];
kernel = patch.wire_node(&renamed_kernel, div(), &[kernel, operand_for_kernel])?[0];
} else if bin.0.is::<Div>() {
bias = patch.wire_node(&renamed, div(), &[operand, bias])?[0];
kernel = patch.wire_node(&renamed, div(), &[operand_for_kernel, kernel])?[0];
bias = patch.wire_node(&renamed_bias, div(), &[operand, bias])?[0];
kernel = patch.wire_node(&renamed_kernel, div(), &[operand_for_kernel, kernel])?[0];
} else if bin.0.is::<Add>() {
bias = patch.wire_node(&renamed, add(), &[bias, operand])?[0];
bias = patch.wire_node(&renamed_bias, add(), &[bias, operand])?[0];
} else if bin.0.is::<Mul>() {
bias = patch.wire_node(&renamed, mul(), &[bias, operand])?[0];
kernel = patch.wire_node(&renamed, mul(), &[kernel, operand_for_kernel])?[0];
bias = patch.wire_node(&renamed_bias, mul(), &[bias, operand])?[0];
kernel = patch.wire_node(&renamed_kernel, mul(), &[kernel, operand_for_kernel])?[0];
} else {
return Ok(None);
};
2 changes: 2 additions & 0 deletions core/src/ops/scan/mir.rs
@@ -306,6 +306,7 @@ impl Scan {
.context("building axis propagating patch")?
{
patch.apply(&mut new_body)?;
new_body.compute_const_facts()?;
// propagate axis injects new nodes at the end. last successor of input
// in new net will be the new succ
let new_body_scan_input = new_body.input_outlets()?[slot];
@@ -461,6 +462,7 @@ impl Scan {
.context("building axis propagating patch")?
{
patch.apply(&mut new_body)?;
new_body.prop_consts()?;
}
}
let emitter_outlet = new_body.output_outlets()?[mapping_ix];
9 changes: 7 additions & 2 deletions core/src/optim/mod.rs
@@ -51,6 +51,10 @@ impl Optimizer {
Optimizer { steps: Some(steps), ..self }
}

pub fn prop_consts() -> Optimizer {
Optimizer::passes(vec![Box::new(PropConst)])
}

pub fn declutter() -> Optimizer {
Optimizer::passes(vec![
Box::new(PropConst),
@@ -166,8 +170,9 @@ impl<'o> OptimizerSession<'o> {
self.seen.insert(watchdog);
}
}
debug!("applying patch #{}: {}", self.counter, patch.context.iter().rev().join(" >> "),);
patch.apply(model)?;
let patch_name = patch.context.iter().rev().join(" >> ");
debug!("applying patch #{}: {patch_name}", self.counter);
patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
model
.check_consistency()
.context("Checking target model consistency after patching")?;
50 changes: 21 additions & 29 deletions core/src/optim/prop_const.rs
@@ -1,7 +1,6 @@
use tract_data::UndeterminedSymbol;

use crate::internal::*;
use crate::ops::konst::Const;
use crate::optim::OptimizerSession;

#[derive(Clone, Debug)]
@@ -19,36 +18,29 @@ impl super::TypedPass for PropConst {
let mut patch = TypedModelPatch::default();
for n in model.eval_order()? {
let node = model.node(n);
// bit of a hack here. pulse erase const from fact, so this little dance puts it back
if node.op_is::<Const>() && node.outputs[0].fact.konst.is_none() {
let k = node.op_as::<Const>().unwrap().0.clone();
let wire = patch.add_const(&node.name, k)?;
patch.shunt_outside(model, node.id.into(), wire)?;
}
if node.op.is_stateless() && !node.op_is::<Const>() {
if let Some(inputs) = model
.node_input_facts(n)?
.iter()
.map(|f| f.konst.clone().map(|t| t.into_tvalue()))
.collect()
{
match node.op.eval_with_session(&SessionState::default(), inputs) {
Ok(res) => {
for (ix, output) in res.into_iter().enumerate() {
let mut name = node.name.clone();
if ix > 0 {
name = format!("{name}.{ix}");
}
let wire = patch.add_const(name, output.into_arc_tensor())?;
patch.shunt_outside(model, (n, ix).into(), wire)?;
let (inputs, outputs) = model.node_facts(n)?;
if node.op.is_stateless()
&& inputs.iter().all(|i| i.konst.is_some())
&& outputs.iter().any(|o| o.konst.is_none())
{
let inputs =
inputs.iter().map(|f| f.konst.clone().unwrap().into_tvalue()).collect();
match node.op.eval_with_session(&SessionState::default(), inputs) {
Ok(res) => {
for (ix, output) in res.into_iter().enumerate() {
let mut name = node.name.clone();
if ix > 0 {
name = format!("{name}.{ix}");
}
let wire = patch.add_const(name, output.into_arc_tensor())?;
patch.shunt_outside(model, (n, ix).into(), wire)?;
}
Err(e) => {
if !e.root_cause().is::<UndeterminedSymbol>() {
Err(e).with_context(|| {
format!("Eager eval {} during optimisation", model.node(n))
})?;
}
}
Err(e) => {
if !e.root_cause().is::<UndeterminedSymbol>() {
Err(e).with_context(|| {
format!("Eager eval {} during optimisation", model.node(n))
})?;
}
}
}
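The rewritten pass now keys off facts rather than special-casing Const ops; its eligibility test can be read as a standalone predicate (hypothetical refactoring for illustration, not code from the commit):

```rust
use tract_core::internal::*;

// A node is worth folding when it is stateless, every input fact already carries a constant
// value, and at least one output fact does not carry one yet.
fn is_foldable(model: &TypedModel, n: usize) -> TractResult<bool> {
    let node = model.node(n);
    let (inputs, outputs) = model.node_facts(n)?;
    Ok(node.op.is_stateless()
        && inputs.iter().all(|i| i.konst.is_some())
        && outputs.iter().any(|o| o.konst.is_none()))
}
```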