Skip to content

Commit

Permalink
Try and isolate leak
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Aug 11, 2021
1 parent 1dc6bc0 commit d135154
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
18 changes: 18 additions & 0 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ impl Function {

let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);

// // This is a temporary patch to ensure that the arguments are correclty dropped.
// let args: Vec<ArgValue> = values.into_iter().zip(type_codes.into_iter()).map(|(value, type_code)| {
// ArgValue::from_tvm_value(value, type_code)
// }).collect();

// let mut objects_to_drop: Vec<crate::ObjectRef> = vec![];
// for arg in args {
// match arg {
// ArgValue::ObjectHandle(_) | ArgValue::ModuleHandle(_) | ArgValue::NDArrayHandle(_) => objects_to_drop.push(arg.try_into().unwrap()),
// _ => {}
// }
// }

// drop(objects_to_drop);

let obj: crate::ObjectRef = rv.clone().try_into().unwrap();
println!("rv: {}", obj.count());

Ok(rv)
}
}
Expand Down
27 changes: 15 additions & 12 deletions rust/tvm/examples/resnet/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,27 @@ fn main() -> anyhow::Result<()> {
"/deploy_lib.so"
)))?;

let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?;

// parse parameters and convert to TVMByteArray
let params: Vec<u8> = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?;

println!("param bytes: {}", params.len());

graph_rt.load_params(&params)?;
graph_rt.set_input("data", input)?;
graph_rt.run()?;
let mut output: Vec<f32>;

loop {
let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?;

// prepare to get the output
let output_shape = &[1, 1000];
let output = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1));
graph_rt.get_output_into(0, output.clone())?;
graph_rt.load_params(&params)?;
graph_rt.set_input("data", input.clone())?;
graph_rt.run()?;

// flatten the output as Vec<f32>
let output = output.to_vec::<f32>()?;
// prepare to get the output
let output_shape = &[1, 1000];
let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1));
graph_rt.get_output_into(0, output_nd.clone())?;

// flatten the output as Vec<f32>
output = output_nd.to_vec::<f32>()?;
}

// find the maximum entry in the output and its index
let (argmax, max_prob) = output
Expand Down

0 comments on commit d135154

Please sign in to comment.