Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dump accepts --set for profiling #1512

Merged
merged 2 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,6 @@ fn main() -> TractResult<()> {
.help("Save the output tensor into a folder of nnef .dat files"),
)
.arg(Arg::new("steps").long("steps").help("Show all inputs and outputs"))
.arg(
Arg::new("set")
.long("set")
.takes_value(true)
.multiple_occurrences(true)
.number_of_values(1)
.help("Set a symbol value before running the model (--set S=12)"),
)
.arg(
Arg::new("save-steps")
.long("save-steps")
Expand Down Expand Up @@ -440,6 +432,14 @@ fn run_options(command: clap::Command) -> clap::Command {
.takes_value(true)
.help("Path to an input container (.npz). This sets tensor values."),
)
.arg(
Arg::new("set")
.long("set")
.takes_value(true)
.multiple_occurrences(true)
.number_of_values(1)
.help("Set a symbol value before running the model (--set S=12)"),
)
.arg(
Arg::new("input-from-nnef").long("input-from-nnef").takes_value(true).help(
"Path to a directory containing input tensors in NNEF format (.dat files). This sets tensor values.",
Expand Down
11 changes: 0 additions & 11 deletions cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,6 @@ fn run_regular(
dispatch_model!(tract, |m| {
let plan = SimplePlan::new(m)?;
let mut state = SimpleState::new(plan)?;
if let Some(set) = sub_matches.values_of("set") {
for set in set {
let mut tokens = set.split('=');
let sym = tokens.next().context("--set expect S=12 form")?;
let value = tokens.next().context("--set expect S=12 form")?;
let sym = state.model().symbols.sym(sym).to_owned();
let value: i64 = value.parse().context("Can not parse symbol value in set")?;
state.session_state.resolved_symbols =
state.session_state.resolved_symbols.with(&sym, value);
}
}
let inputs = tract_libcli::tensor::retrieve_or_make_inputs(tract, run_params)?;
let mut results = tvec!(vec!(); state.model().outputs.len());
let multiturn = inputs.len() > 1;
Expand Down
12 changes: 11 additions & 1 deletion cli/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,15 @@ pub fn run_params_from_subcommand(
let allow_float_casts: bool =
params.allow_float_casts || sub_matches.is_present("allow-float-casts");

Ok(RunParams { tensors_values: tv, allow_random_input, allow_float_casts })
let mut symbols = SymbolValues::default();
if let Some(set) = sub_matches.values_of("set") {
for set in set {
let (sym, value) = set.split_once('=').context("--set expect S=12 form")?;
let sym = params.tract_model.get_or_intern_symbol(sym);
let value: i64 = value.parse().with_context(|| format!("Can not parse symbol value in set {set}"))?;
symbols.set(&sym, value);
}
}

Ok(RunParams { tensors_values: tv, allow_random_input, allow_float_casts, symbols })
}
4 changes: 3 additions & 1 deletion libcli/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ pub struct RunParams {
pub tensors_values: TensorsValues,
pub allow_random_input: bool,
pub allow_float_casts: bool,
pub symbols: SymbolValues,
}

pub fn retrieve_or_make_inputs(
Expand Down Expand Up @@ -363,12 +364,13 @@ pub fn retrieve_or_make_inputs(
bail!("For input {}, can not reconcile model input fact {:?} with provided input {:?}", name, fact, value[0]);
};
} else if params.allow_random_input {
let fact = tract.outlet_typedfact(*input)?;
let mut fact:TypedFact = tract.outlet_typedfact(*input)?.clone();
info_once(format!("Using random input for input called {name:?}: {fact:?}"));
let tv = params
.tensors_values
.by_name(name)
.or_else(|| params.tensors_values.by_input_ix(ix));
fact.shape = fact.shape.iter().map(|dim| dim.eval(&params.symbols)).collect();
tmp.push(vec![crate::tensor::tensor_for_fact(&fact, None, tv)?.into()]);
} else {
bail!("Unmatched tensor {}. Fix the input or use \"--allow-random-input\" if this was intended", name);
Expand Down
Loading