Skip to content

Commit

Permalink
dump accepts --set for profiling
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 2, 2024
1 parent a029d74 commit f47c6f5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
16 changes: 8 additions & 8 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,6 @@ fn main() -> TractResult<()> {
.help("Save the output tensor into a folder of nnef .dat files"),
)
.arg(Arg::new("steps").long("steps").help("Show all inputs and outputs"))
.arg(
Arg::new("set")
.long("set")
.takes_value(true)
.multiple_occurrences(true)
.number_of_values(1)
.help("Set a symbol value before running the model (--set S=12)"),
)
.arg(
Arg::new("save-steps")
.long("save-steps")
Expand Down Expand Up @@ -440,6 +432,14 @@ fn run_options(command: clap::Command) -> clap::Command {
.takes_value(true)
.help("Path to an input container (.npz). This sets tensor values."),
)
.arg(
Arg::new("set")
.long("set")
.takes_value(true)
.multiple_occurrences(true)
.number_of_values(1)
.help("Set a symbol value before running the model (--set S=12)"),
)
.arg(
Arg::new("input-from-nnef").long("input-from-nnef").takes_value(true).help(
"Path to a directory containing input tensors in NNEF format (.dat files). This sets tensor values.",
Expand Down
2 changes: 2 additions & 0 deletions cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ fn run_regular(
dispatch_model!(tract, |m| {
let plan = SimplePlan::new(m)?;
let mut state = SimpleState::new(plan)?;
/*
if let Some(set) = sub_matches.values_of("set") {
for set in set {
let mut tokens = set.split('=');
Expand All @@ -168,6 +169,7 @@ fn run_regular(
state.session_state.resolved_symbols.with(&sym, value);
}
}
*/
let inputs = tract_libcli::tensor::retrieve_or_make_inputs(tract, run_params)?;
let mut results = tvec!(vec!(); state.model().outputs.len());
let multiturn = inputs.len() > 1;
Expand Down
12 changes: 11 additions & 1 deletion cli/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,15 @@ pub fn run_params_from_subcommand(
let allow_float_casts: bool =
params.allow_float_casts || sub_matches.is_present("allow-float-casts");

Ok(RunParams { tensors_values: tv, allow_random_input, allow_float_casts })
let mut symbols = SymbolValues::default();
if let Some(set) = sub_matches.values_of("set") {
for set in set {
let (sym, value) = set.split_once('=').context("--set expect S=12 form")?;
let sym = params.tract_model.get_or_intern_symbol(sym);
let value: i64 = value.parse().with_context(|| format!("Can not parse symbol value in set {set}"))?;
symbols.set(&sym, value);
}
}

Ok(RunParams { tensors_values: tv, allow_random_input, allow_float_casts, symbols })
}
4 changes: 3 additions & 1 deletion libcli/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ pub struct RunParams {
pub tensors_values: TensorsValues,
pub allow_random_input: bool,
pub allow_float_casts: bool,
pub symbols: SymbolValues,
}

pub fn retrieve_or_make_inputs(
Expand Down Expand Up @@ -363,12 +364,13 @@ pub fn retrieve_or_make_inputs(
bail!("For input {}, can not reconcile model input fact {:?} with provided input {:?}", name, fact, value[0]);
};
} else if params.allow_random_input {
let fact = tract.outlet_typedfact(*input)?;
let mut fact:TypedFact = tract.outlet_typedfact(*input)?.clone();
info_once(format!("Using random input for input called {name:?}: {fact:?}"));
let tv = params
.tensors_values
.by_name(name)
.or_else(|| params.tensors_values.by_input_ix(ix));
fact.shape = fact.shape.iter().map(|dim| dim.eval(&params.symbols)).collect();
tmp.push(vec![crate::tensor::tensor_for_fact(&fact, None, tv)?.into()]);
} else {
bail!("Unmatched tensor {}. Fix the input or use \"--allow-random-input\" if this was intended", name);
Expand Down

0 comments on commit f47c6f5

Please sign in to comment.