From 1853bcc36098d6965197f5b82670c4e4513004d0 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 18 Feb 2019 13:21:02 -0800 Subject: [PATCH] Fix the FInplaceIdentity (#2572) --- nnvm/src/pass/plan_memory.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 0261412d596ff..eff7fb9a59396 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -218,10 +218,14 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 && fignore_inputs[inode.source->op()]( inode.source->attrs).size() == inode.source->num_inputs()); + // Identity should only be true if shape.Size() and types match + bool real_identity = identity[ipair] && + shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && + dtype_vec[eid_out] == dtype_vec[eid_in]; if (taken[kv.first] == false && sid_out == GraphAllocator::kBadStorageID && sid_in >= 0 && - ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) && + ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) && entry_ref_count[eid_out] > 0 && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && (dtype_vec[eid_out] == dtype_vec[eid_in] ||