diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 414fabae8b64..766b8b691f07 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1283,6 +1283,15 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXAutogradDropGrads(uint32_t num_var, NDArrayHandle* var_handles);
+/*!
+ * \brief mark nonleaf NDArrays as variables during deferred computation
+ * \param num_nleafs number of nonleaf NDArrays
+ * \param cnt_var count of existing marked nonleaf variables
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle *nleaf_handles,
+                                       int num_nleafs,
+                                       int cnt_var);
 /*!
  * \brief unmark nonleaf NDArrays to free the memory
  * \param num_var number of variable NDArrays
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 42876f7bf445..65653cc9a890 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -290,6 +290,8 @@ class Imperative {
   void MarkVariables(const std::vector<NDArray*>& variables,
                      const std::vector<uint32_t>& grad_reqs,
                      const std::vector<NDArray*>& gradients);
+  /*! \brief mark nonleaf variables during DC for computing gradients. */
+  void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
   /*! \brief unmark nonleaf variables to free the memory. */
   void DropGrads(const std::vector<NDArray*>& variables);
   /*! \brief compute the gradient of outputs w.r.t variables. */
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index bed166a9307e..51fe5a9c579a 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -351,6 +351,8 @@ class NDArray {
   bool fresh_out_grad() const;
   /*! \return updated grad state in autograd_entry_ */
   void set_fresh_out_grad(bool state) const;
+  /*! \brief copy the autograd_entry_ from src NDArray */
+  void copy_autograd_entry_(const NDArray* src);
   /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
    * Throws an exception if the indices array shape is inconsistent
    * Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
diff --git a/python/mxnet/_ctypes/cached_op.py b/python/mxnet/_ctypes/cached_op.py
index 509484b7c3e4..fd5d6a9c0c1e 100644
--- a/python/mxnet/_ctypes/cached_op.py
+++ b/python/mxnet/_ctypes/cached_op.py
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
         if not default_device:
             default_device = kwargs.pop('default_ctx', None)
         out = kwargs.pop('out', None)
+        nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
         if kwargs:
             raise TypeError(
                 "CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
             *args,
             type_id,
             device_id,
-            *out_arg
+            len(out_arg),
+            *out_arg,
+            len(nleaf_vars),
+            *nleaf_vars
         )
         if out is not None:
             return out
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index cff346b9f4aa..9c1a75997f9c 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -33,7 +33,8 @@
 import json
 import numpy as np
 
-from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
+from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
+    _as_list
 from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
     profiler as _profiler, device as _device
 from ..symbol.numpy import _symbol as np_symbol
@@ -1091,6 +1092,7 @@ def __init__(self):
         self._backend_opts = {}
         self._partition_if_dynamic = True
         self._first_forward = True
+        self._nleaf_vars = OrderedDict()
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
         args_without_none = [ele for ele in args if ele is not None]
         cargs = [args_without_none[i] if is_arg else i.data()
                  for is_arg, name, i in self._cached_op_args]
-        out = self._cached_op(*cargs)
+        out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)
@@ -1678,6 +1680,49 @@ def reset_ctx(self, ctx):
         self.reset_device(ctx)
 
 
+    def intermediate(self, names, var_arrays_inp, grad_req='write'):
+        """Mark the intermediate variables.
+
+        Parameters
+        ----------
+        names : str or tuple[str], names of the registered intermediate variables
+        var_arrays_inp : ndarray or tuple[ndarray], the output of the expression
+        grad_req : str, gradient request
+        """
+        if not self._active:
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+        else:
+            prev_val = dc.set_deferred_compute(False)
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            # Prepare ctypes array types
+            import ctypes
+            var_handles_type = ctypes.c_void_p * len(var_arrays)
+            # Convert handles
+            var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
+            check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+            dc.set_deferred_compute(prev_val)
+        return var_arrays_inp
+
+    def attach_grad_intermediate(self):
+        """Attach gradient to all the intermediate variables.
+        """
+        for val in self._nleaf_vars.values():
+            val.data().attach_grad(grad_req=val.grad_req)
+
+    def get_intermediate(self, names):
+        """Get the intermediate variables by names.
+        """
+        if isinstance(names, list):
+            return [self._nleaf_vars[n] for n in names]
+        else:
+            return self._nleaf_vars[names]
+
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
     as feature extractors. For example, you may want to extract the output
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 1b396490a7fb..8cb4ac56b008 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -773,3 +773,40 @@ def grad_req(self, req):
             warnings.warn('Constant parameter "{}" does not support '
                           'grad_req other than "null", and new value "{}" '
                           'is ignored.'.format(self.name, req))
+
+class Intermediate:
+    """A container holding marked intermediate variables of Blocks.
+
+    Parameters
+    ----------
+    name : str
+        Name of this intermediate variable. It is used to retrieve the marked variable.
+    grad_req : {'write', 'add', 'null'}, default 'write'
+        Specifies how to update gradient to grad arrays.
+
+        - ``'write'`` means every time gradient is written to grad :py:class:`NDArray`.
+        - ``'add'`` means every time gradient is added to the grad :py:class:`NDArray`. You need
+          to manually call ``zero_grad()`` to clear the gradient buffer before each
+          iteration when using this option.
+        - ``'null'`` means gradient is not requested for this variable. Gradient arrays
+          will not be allocated.
+    """
+    def __init__(self, name, data=None, grad_req='write'):
+        self._name = name
+        self._data = data
+        self._grad_req = grad_req
+
+    def __repr__(self):
+        s = 'Intermediate name={name}'
+        return s.format(name=self._name)
+
+    def data(self):
+        return self._data
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def grad_req(self):
+        return self._grad_req
diff --git a/src/api/cached_op_api.cc b/src/api/cached_op_api.cc
index 79494ea80bcf..4d9530b78129 100644
--- a/src/api/cached_op_api.cc
+++ b/src/api/cached_op_api.cc
@@ -44,19 +44,21 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
     ndinputs.push_back(static_cast<NDArray*>(args[i]));
   }
 
-  std::vector<NDArray*> ndoutputs;
-  ndoutputs.reserve(op->num_outputs());
-  if (args[num_inputs + 4].type_code() == kNull) {
-    for (int i = 0; i < op->num_outputs(); ++i)
-      ndoutputs.push_back(new NDArray());
-  } else {
-    int array_size = args_size - num_inputs - 4;
-    CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
-                                            << " outputs, but " << array_size << " was given.";
-    for (int i = num_inputs + 4; i < array_size; ++i) {
-      ndoutputs.push_back(args[i].operator mxnet::NDArray*());
-    }
-  }
+  int num_outputs = args[num_inputs + 4];
+  int num_nleafs = args[num_inputs + num_outputs + 5];
+  std::vector<NDArray*> ndoutputs;
+  ndoutputs.reserve(op->num_outputs());
+  if (args[num_inputs + 5].type_code() == kNull) {
+    for (int i = 0; i < op->num_outputs(); ++i) ndoutputs.push_back(new NDArray());
+  } else {
+    int array_size = args_size - num_inputs - num_nleafs - 6;
+    CHECK_EQ(array_size, op->num_outputs())
+        << "CachedOp expects " << op->num_outputs() << " outputs, but "
+        << array_size << " was given.";
+    for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) {
+      ndoutputs.push_back(args[i].operator mxnet::NDArray*());
+    }
+  }
 
   int default_dev_type;
   int default_dev_id;
@@ -69,10 +71,17 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
     default_dev_id = ctx.dev_id;
   }
 
-  // construct default context
-  Context ctx =
-      Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);
-  op->Forward(op_shared, ndinputs, ndoutputs, ctx);
+  std::vector<NDArray*> nleafs;
+  nleafs.reserve(num_nleafs);
+  for (int i = 0; i < num_nleafs; ++i) {
+    nleafs.push_back(static_cast<NDArray*>(args[i + num_inputs + num_outputs + 6]));
+  }
+  op->set_nleafs(nleafs);
+
+  // construct default context
+  Context ctx = Context::Create(static_cast<Context::DeviceType>(default_dev_type),
+                                default_dev_id);
+  op->Forward(op_shared, ndinputs, ndoutputs, ctx);
 
   if (op->num_outputs() == 1) {
     *ret = ndoutputs[0];
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index b91a997b7ce1..2215f2525a53 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
   *out = s;
   API_END_HANDLE_ERROR(delete s;);
 }
+
+int MXNDArrayMarkDCVariables(NDArrayHandle *nleaf_handles, int num_nleafs, int cnt_var) {
+  API_BEGIN();
+  std::vector<NDArray*> nleafs;
+  nleafs.reserve(num_nleafs);
+  for (int i = 0; i < num_nleafs; ++i) {
+    NDArray *array = reinterpret_cast<NDArray*>(nleaf_handles[i]);
+    nleafs.emplace_back(array);
+  }
+  Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
+  API_END();
+}
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 894ef09a1d16..a80ef76a02fc 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -39,9 +39,9 @@ nnvm::Symbol CachedOp::GetOptimizedSymbol() const {
return ret.Copy(); } -CachedOp::CachedOp(const nnvm::Symbol& sym, - const std::vector >& flags) - : sym_(sym), flags_(flags) { +CachedOp::CachedOp( + const nnvm::Symbol& sym, + const std::vector >& flags) : sym_(sym), flags_(flags) { config_.Init(flags); this->dynamic_shape_checked_ = false; @@ -51,30 +51,25 @@ CachedOp::CachedOp(const nnvm::Symbol& sym, auto grad_graph = nnvm::Graph(); std::unordered_map fwd_input_to_grad_output; - CreateFullGraph(sym.Copy(), - &fwd_graph_, - &grad_graph, - &full_graph_, - &ograd_entries_, - &fwd_input_to_grad_output); + CreateFullGraph(sym.Copy(), &fwd_graph_, &grad_graph, &full_graph_, + &ograd_entries_, &fwd_input_to_grad_output); { - const auto& idx = fwd_graph_.indexed_graph(); + const auto& idx = fwd_graph_.indexed_graph(); bwd_output_reqs_ = std::vector(grad_graph.outputs.size(), kWriteTo); - inlining_ = !config_.static_alloc && - (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; + inlining_ = !config_.static_alloc && + (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; } SetInputIndices(fwd_graph_, config_.param_indices, &config_.data_indices); // Set the backward dependency vectors { - const auto& idx = full_graph_.indexed_graph(); - size_t num_forward_inputs = num_inputs(); + const auto& idx = full_graph_.indexed_graph(); + size_t num_forward_inputs = num_inputs(); size_t num_forward_outputs = num_outputs(); for (uint32_t i = 0; i < ograd_entries_.size(); ++i) { - if (!idx.exist(ograd_entries_[i].node.get())) - continue; + if (!idx.exist(ograd_entries_[i].node.get())) continue; bwd_ograd_dep_.push_back(i); } save_inputs_.resize(num_forward_inputs, false); @@ -94,15 +89,16 @@ CachedOp::CachedOp(const nnvm::Symbol& sym, CachedOp::~CachedOp() = default; -std::vector CachedOp::Gradient(const nnvm::ObjectPtr& node, - const std::vector& ograds) const { +std::vector CachedOp::Gradient( + const nnvm::ObjectPtr& node, + const std::vector& ograds) const { using namespace nnvm; static const auto _backward_CachedOp = Op::Get("_backward_CachedOp"); - static const auto _NoGrad = Op::Get("_NoGradient"); + static const auto _NoGrad = Op::Get("_NoGradient"); - auto p = Node::Create(); - p->attrs.op = _backward_CachedOp; - p->attrs.name = node->attrs.name + "_backward"; + auto p = Node::Create(); + p->attrs.op = _backward_CachedOp; + p->attrs.name = node->attrs.name + "_backward"; p->attrs.parsed = node->attrs.parsed; p->control_deps.push_back(node); p->inputs.reserve(bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size()); @@ -116,10 +112,10 @@ std::vector CachedOp::Gradient(const nnvm::ObjectPtr& node, ret.reserve(num_inputs()); const auto& auxs = mutable_input_nodes(); if (auxs.size()) { - auto nop = Node::Create(); - nop->attrs.op = _NoGrad; + auto nop = Node::Create(); + nop->attrs.op = _NoGrad; nop->attrs.name = "NoGradient"; - uint32_t k = 0; + uint32_t k = 0; for (const auto& i : fwd_graph_.indexed_graph().input_nodes()) { if (auxs.count(i)) { ret.emplace_back(nop); @@ -129,7 +125,7 @@ std::vector CachedOp::Gradient(const nnvm::ObjectPtr& node, } } else { for (uint32_t i = 0; i < num_inputs(); ++i) - ret.emplace_back(p, i, 0); + ret.emplace_back(p, i, 0); } return ret; } @@ -147,7 +143,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, CHECK_EQ(inputs.size(), num_inputs()); auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); nnvm::Graph& g = state.info.fwd_graph; ShapeVector shape_inputs(inputs.size()); @@ -158,7 
+154,9 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, // If so, the pass will fail with `contain_dynamic_shape = true`, // This method is only called once, so the overhead is negligible. bool contain_dynamic_shape = false; - CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); + CheckAndInferShape(&g, std::move(shape_inputs), true, + {0, 0}, {0, 0}, + &contain_dynamic_shape); if (!config_.static_shape && erase_result) { g.attrs.erase("shape"); g.attrs.erase("shape_inputs"); @@ -166,10 +164,11 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, return contain_dynamic_shape; } -bool CachedOp::SetForwardGraph(const Context& default_ctx, - GraphInfo* info, - const bool recording, - const std::vector& inputs) { +bool CachedOp::SetForwardGraph( + const Context& default_ctx, + GraphInfo* info, + const bool recording, + const std::vector& inputs) { using namespace nnvm; using namespace imperative; CHECK_EQ(inputs.size(), num_inputs()); @@ -179,18 +178,19 @@ bool CachedOp::SetForwardGraph(const Context& default_ctx, DTypeVector dtype_inputs(inputs.size()); StorageTypeVector storage_type_inputs(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { - shape_inputs[i] = inputs[info->input_map[i]]->shape(); - dtype_inputs[i] = inputs[info->input_map[i]]->dtype(); + shape_inputs[i] = inputs[info->input_map[i]]->shape(); + dtype_inputs[i] = inputs[info->input_map[i]]->dtype(); storage_type_inputs[i] = inputs[info->input_map[i]]->storage_type(); } - bool match = true; + bool match = true; bool contain_dynamic_shape = false; - match &= - CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); + match &= CheckAndInferShape(&g, std::move(shape_inputs), true, + {0, 0}, {0, 0}, &contain_dynamic_shape); match &= CheckAndInferType(&g, std::move(dtype_inputs), true); exec::DevMaskVector dev_mask(g.indexed_graph().num_nodes(), default_ctx.dev_mask()); - match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(storage_type_inputs), true); + match &= CheckAndInferStorageType(&g, std::move(dev_mask), + std::move(storage_type_inputs), true); // When dynmaic shape exists, it is not feasible to plan memory ahead of time if (contain_dynamic_shape) { @@ -212,8 +212,7 @@ bool CachedOp::SetForwardGraph(const Context& default_ctx, const auto& stypes = g.GetAttr("storage_type"); CHECK_EQ(stypes.size(), storage.size()); for (size_t i = 0; i < stypes.size(); i++) { - if (stypes[i] != kDefaultStorage) - storage[i] = exec::kDynamicStorageID; + if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; } for (const auto i : idx.input_nodes()) { storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; @@ -222,11 +221,11 @@ bool CachedOp::SetForwardGraph(const Context& default_ctx, storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID; } - auto mem_plan = MXPlanMemory(&g, - std::move(storage), - g.GetAttr >(AddPrefix(prefix, REF_COUNT)), - AddPrefix(prefix, STORAGE_PLAN)); - g.attrs[AddPrefix(prefix, MEM_PLAN)] = std::make_shared(std::move(mem_plan)); + auto mem_plan = MXPlanMemory( + &g, std::move(storage), g.GetAttr >(AddPrefix(prefix, REF_COUNT)), + AddPrefix(prefix, STORAGE_PLAN)); + g.attrs[AddPrefix(prefix, MEM_PLAN)] = + std::make_shared(std::move(mem_plan)); return false; } @@ -237,7 +236,7 @@ void SetBackwardInputEid(const std::vector& bwd_in_dep, const std::vector& bwd_ograd_dep, const std::vector& ograd_entries, const nnvm::IndexedGraph& idx, - std::vector* 
bwd_input_eid) { + std::vector *bwd_input_eid) { for (const auto& i : bwd_ograd_dep) { auto ograd = ograd_entries[i]; if (idx.exist(ograd.node.get())) { @@ -256,24 +255,24 @@ void SetBackwardInputEid(const std::vector& bwd_in_dep, } } -bool CachedOp::SetBackwardGraph(GraphInfo* info, - const std::vector& reqs, - const std::vector& inputs, - bool detect_inplace_addto) { +bool CachedOp::SetBackwardGraph( + GraphInfo* info, + const std::vector& reqs, + const std::vector& inputs, + bool detect_inplace_addto) { using namespace nnvm; using namespace imperative; std::lock_guard lock(mutex_); Context default_ctx = inputs[0]->ctx(); - nnvm::Graph& g = info->full_graph; + nnvm::Graph& g = info->full_graph; if (info->bwd_output_reqs != reqs) { info->bwd_output_reqs = reqs; info->bwd_input_eid.clear(); - g = nnvm::Graph(); + g = nnvm::Graph(); g.outputs = info->fwd_graph.outputs; for (size_t i = 0; i < info->grad_graph.outputs.size(); ++i) { - if (info->bwd_output_reqs[i] == kNullOp) - continue; + if (info->bwd_output_reqs[i] == kNullOp) continue; g.outputs.emplace_back(info->grad_graph.outputs[i]); } g.attrs["context"] = std::make_shared( @@ -284,27 +283,25 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info, if (info->bwd_input_eid.size() != inputs.size()) { info->bwd_input_eid.clear(); - SetBackwardInputEid( - bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, info->ograd_entries, idx, &info->bwd_input_eid); + SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, + info->ograd_entries, idx, &info->bwd_input_eid); CHECK_EQ(inputs.size(), info->bwd_input_eid.size()); } - size_t num_forward_nodes = info->fwd_graph.indexed_graph().num_nodes(); + size_t num_forward_nodes = info->fwd_graph.indexed_graph().num_nodes(); size_t num_forward_entries = info->fwd_graph.indexed_graph().num_node_entries(); if (!g.attrs.count(AddPrefix(BACKWARD, REF_COUNT))) { std::vector ref_count(idx.num_node_entries(), 0); for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { - for (const auto& j : idx[i].inputs) - ++ref_count[idx.entry_id(j)]; + for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; } for (size_t i = 0; i < inputs.size(); ++i) { if (info->bwd_input_eid[i] != kEidNotExist) { ++ref_count[info->bwd_input_eid[i]]; } } - for (const auto& i : idx.outputs()) - ++ref_count[idx.entry_id(i)]; + for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared(std::move(ref_count)); } @@ -334,22 +331,24 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info, if (info->bwd_input_eid[i] == kEidNotExist) { continue; } - size_t oi = BwdOriginalInput(info->input_map, i); + size_t oi = BwdOriginalInput(info->input_map, i); shapes[info->bwd_input_eid[i]] = inputs[oi]->shape(); dtypes[info->bwd_input_eid[i]] = inputs[oi]->dtype(); stypes[info->bwd_input_eid[i]] = inputs[oi]->storage_type(); } std::pair node_range, entry_range; - node_range = {num_forward_nodes, idx.num_nodes()}; + node_range = {num_forward_nodes, idx.num_nodes()}; entry_range = {num_forward_entries, idx.num_node_entries()}; bool match = true; - match &= CheckAndInferShape(&g, std::move(shapes), false, node_range, entry_range); - match &= CheckAndInferType(&g, std::move(dtypes), false, node_range, entry_range); + match &= CheckAndInferShape(&g, std::move(shapes), false, + node_range, entry_range); + match &= CheckAndInferType(&g, std::move(dtypes), false, + node_range, entry_range); exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask()); - match &= CheckAndInferStorageType( 
- &g, std::move(dev_mask), std::move(stypes), false, node_range, entry_range); + match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes), + false, node_range, entry_range); if (!match) { g.attrs.erase(AddPrefix(BACKWARD, MEM_PLAN)); @@ -360,15 +359,11 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info, StorageVector storage(idx.num_node_entries(), exec::kBadStorageID); const auto& bwd_stypes = g.GetAttr("storage_type"); for (size_t i = 0; i < bwd_stypes.size(); i++) { - if (bwd_stypes[i] != kDefaultStorage) - storage[i] = exec::kDynamicStorageID; + if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; } - for (size_t i = 0; i < num_forward_entries; ++i) - storage[i] = exec::kExternalStorageID; - for (const auto i : idx.input_nodes()) - storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; - for (const auto i : idx.outputs()) - storage[idx.entry_id(i)] = exec::kExternalStorageID; + for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID; + for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; + for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID; auto mem_plan = MXPlanMemory(&g, std::move(storage), @@ -382,7 +377,8 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info, return false; } -OpStatePtr CachedOp::GetCachedOpState(const Context& ctx) { +OpStatePtr CachedOp::GetCachedOpState( + const Context& ctx) { std::lock_guard lock(mutex_); for (const auto& i : cached_op_states_[ctx]) { // only create one state per device when not using static memory @@ -390,50 +386,52 @@ OpStatePtr CachedOp::GetCachedOpState(const Context& ctx) { return i; } } - auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_, inlining_); + auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_, + inlining_); cached_op_states_[ctx].push_back(state_ptr); return state_ptr; } -void CachedOp::StaticAllocMemory(const OpStatePtr& state_ptr, bool recording, bool keep_fwd) { +void CachedOp::StaticAllocMemory( + const OpStatePtr& state_ptr, + bool recording, + bool keep_fwd) { using namespace nnvm; using namespace imperative; - auto& state = state_ptr.get_state(); - const auto& default_ctx = state.context; - nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; - const auto& idx = g.indexed_graph(); + auto& state = state_ptr.get_state(); + const auto& default_ctx = state.context; + nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; + const auto& idx = g.indexed_graph(); const std::string& graph_type = keep_fwd ? BACKWARD : (recording ? FULL : FORWARD); const auto& storage_plan_attr = AddPrefix(graph_type, STORAGE_PLAN); - const auto& storage_plan = g.GetAttr >(storage_plan_attr); - const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); + const auto& storage_plan = g.GetAttr >(storage_plan_attr); + const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); std::vector addto_entry; if (g.attrs.count("addto_entry")) { addto_entry = g.GetAttr >("addto_entry"); } - size_t start_eid = keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0; - size_t end_eid = idx.num_node_entries(); + size_t start_eid = + keep_fwd ? 
state.info.fwd_graph.indexed_graph().num_node_entries() : 0; + size_t end_eid = idx.num_node_entries(); - if (!keep_fwd) - state.fwd_alloc = false; + if (!keep_fwd) state.fwd_alloc = false; state.bwd_alloc = false; for (size_t i = start_eid; i < state.buff.size(); ++i) { - state.buff[i] = NDArray(); - state.arrays[i] = &state.buff[i]; - state.array_reqs[i] = kNullOp; + state.buff[i] = NDArray(); + state.arrays[i] = &state.buff[i]; + state.array_reqs[i] = kNullOp; state.dynamic_entries[i] = false; } for (auto i : idx.input_nodes()) { auto eid = idx.entry_id(i, 0); - if (eid >= start_eid) - state.dynamic_entries[eid] = true; + if (eid >= start_eid) state.dynamic_entries[eid] = true; } for (auto i : idx.outputs()) { auto eid = idx.entry_id(i); - if (eid >= start_eid) - state.dynamic_entries[eid] = true; + if (eid >= start_eid) state.dynamic_entries[eid] = true; } for (size_t i = start_eid; i < end_eid; ++i) { @@ -450,15 +448,9 @@ void CachedOp::StaticAllocMemory(const OpStatePtr& state_ptr, bool recording, bo } auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool; - reuse_pool = imperative::AllocateMemory(g, - idx, - default_ctx, - start_eid, - end_eid, - mem_plan, - state.arrays, - &state.array_reqs, - std::move(reuse_pool)); + reuse_pool = imperative::AllocateMemory( + g, idx, default_ctx, start_eid, end_eid, mem_plan, + state.arrays, &state.array_reqs, std::move(reuse_pool)); state.recording = recording; if (keep_fwd) { @@ -468,23 +460,26 @@ void CachedOp::StaticAllocMemory(const OpStatePtr& state_ptr, bool recording, bo } } -void CachedOp::StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool keep_fwd) { +void CachedOp::StaticInitExec( + const OpStatePtr& state_ptr, + bool recording, + bool keep_fwd) { using namespace nnvm; using namespace imperative; - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); const auto& default_ctx = state.context; - nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; - const auto& idx = g.indexed_graph(); + nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; + const auto& idx = g.indexed_graph(); std::vector skip_plus_node; if (g.attrs.count("skip_plus_node")) { skip_plus_node = g.GetAttr >("skip_plus_node"); } - size_t start_nid = keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0; - size_t end_nid = idx.num_nodes(); + size_t start_nid = + keep_fwd ? 
state.info.fwd_graph.indexed_graph().num_nodes() : 0; + size_t end_nid = idx.num_nodes(); - if (!keep_fwd) - state.fwd_exec_init = false; + if (!keep_fwd) state.fwd_exec_init = false; state.bwd_exec_init = false; for (size_t i = start_nid; i < state.execs.size(); ++i) { @@ -495,7 +490,7 @@ void CachedOp::StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool if (!config_.static_shape) { for (size_t i = start_nid; i < end_nid; ++i) { state.opr_segs[i].next_nid = i + 1; - state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i]; + state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i]; } } else { for (size_t i = start_nid; i < end_nid; ++i) { @@ -511,8 +506,7 @@ void CachedOp::StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) { skip = state.dynamic_entries[idx.entry_id(i, j)]; } - if (skip) - continue; + if (skip) continue; SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs); } @@ -530,14 +524,8 @@ void CachedOp::StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool bulk_size = 0; } - CreateEngineOpSeg(idx, - default_ctx, - start_nid, - end_nid, - bulk_size, - state.execs, - skip_plus_node, - &state.opr_segs); + CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, + state.execs, skip_plus_node, &state.opr_segs); } if (keep_fwd) { @@ -547,21 +535,22 @@ void CachedOp::StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool } } -void CachedOp::StaticRunOps(const Context& default_ctx, - const nnvm::Graph& g, - const OpStatePtr& state_ptr, - const std::vector& state_arrays, - size_t start_nid, - size_t end_nid) { - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); +void CachedOp::StaticRunOps( + const Context& default_ctx, + const nnvm::Graph& g, + const OpStatePtr& state_ptr, + const std::vector &state_arrays, + size_t start_nid, + size_t end_nid) { + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; + bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; bool is_training = Imperative::Get()->is_training(); - auto& state = state_ptr.get_state(); - const auto& idx = g.indexed_graph(); + auto& state = state_ptr.get_state(); + const auto& idx = g.indexed_graph(); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - const auto& op_execs = state.execs; + const auto& op_execs = state.execs; std::vector ndinputs, ndoutputs; mxnet::ShapeVector arg_shapes; @@ -569,20 +558,17 @@ void CachedOp::StaticRunOps(const Context& default_ctx, std::vector req; for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) { - if (op_execs[i]) - op_execs[i]->op_ctx.is_train = is_training; + if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training; } for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) { const auto& opr_seg = state.opr_segs[i]; - if (opr_seg.skip) - continue; + if (opr_seg.skip) continue; if (opr_seg.opr != nullptr) { Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling); } else { const nnvm::IndexedGraph::Node& node = idx[i]; - if (node.source->is_variable()) - continue; + if (node.source->is_variable()) continue; auto num_outputs = node.source->num_outputs(); ndinputs.clear(); ndinputs.reserve(node.inputs.size()); @@ -591,7 +577,7 @@ void CachedOp::StaticRunOps(const Context& 
default_ctx, CHECK(!ndinputs.back()->is_none()); } if (monitor_callback_ && monitor_all_) { - mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_); + mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_); } ndoutputs.clear(); ndoutputs.reserve(num_outputs); @@ -637,31 +623,28 @@ void CachedOp::StaticRunOps(const Context& default_ctx, state.op_states[fwd_node_id]); } else { Imperative::Get()->InvokeOp( - default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); + default_ctx, node.source->attrs, ndinputs, ndoutputs, req, + dispatch_mode); } if (monitor_callback_) { - mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_); + mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_); } } } } -#define INIT_DETACHED(x, y) \ - if (!y->is_none()) \ - x->InitDetached(y) +#define INIT_DETACHED(x, y) if (!y->is_none()) x->InitDetached(y) -static void PrepareOutputs(const nnvm::Graph& g, - const Context& default_ctx, - const std::vector& outputs, - std::vector* pArrays, - bool detach) { +static void PrepareOutputs(const nnvm::Graph& g, const Context& default_ctx, + const std::vector &outputs, + std::vector *pArrays, bool detach) { using namespace nnvm; const auto& dtypes = g.GetAttr("dtype"); const auto& shapes = g.GetAttr("shape"); const auto& stypes = g.GetAttr("storage_type"); const auto& idx = g.indexed_graph(); - auto& arrays = *pArrays; + auto &arrays = *pArrays; for (size_t i = 0; i < outputs.size(); ++i) { const auto eid = idx.entry_id(idx.outputs()[i]); // An input and an output may share the same array. @@ -670,22 +653,24 @@ static void PrepareOutputs(const nnvm::Graph& g, arrays[eid] = outputs[i]; if (arrays[eid]->is_none()) - arrays[eid]->ReInit( - static_cast(stypes[eid]), shapes[eid], default_ctx, dtypes[eid]); + arrays[eid]->ReInit(static_cast(stypes[eid]), + shapes[eid], default_ctx, dtypes[eid]); const nnvm::NodeAttrs& attrs = idx[idx.outputs()[i].node_id].source->attrs; - outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); + outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), + attrs.name); } } -OpStatePtr CachedOp::StaticForward(const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs) { +OpStatePtr CachedOp::StaticForward( + const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs) { using namespace nnvm; using namespace imperative; bool recording = Imperative::Get()->is_recording(); auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); // Need to lock the mutex on the state, this allows // for multi context push of ops to dependency engine. @@ -695,11 +680,11 @@ OpStatePtr CachedOp::StaticForward(const Context& default_ctx, std::lock_guard lock(state.mutex); bool match = SetForwardGraph(default_ctx, &state.info, recording, inputs); - match = match && state.recording == recording; + match = match && state.recording == recording; - nnvm::Graph& g = state.info.fwd_graph; + nnvm::Graph& g = state.info.fwd_graph; const auto& idx = g.indexed_graph(); - if (!state.fwd_alloc || !match) { + if (!state.fwd_alloc || !match) { StaticAllocMemory(state_ptr, recording, false); } @@ -707,25 +692,25 @@ OpStatePtr CachedOp::StaticForward(const Context& default_ctx, // The input and output arrays should only be valid for this run, // so we shouldn't modify the state's array list. 
state.arrays_with_in_out = state.arrays; - auto& arrays = state.arrays_with_in_out; + auto& arrays = state.arrays_with_in_out; if (config_.static_shape) { for (auto i : config_.param_indices) { auto nid = idx.input_nodes()[i]; if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[state.info.input_map[i]])) { - match = false; + match = false; auto ptr = &state.buff[idx.entry_id(nid, 0)]; CHECK_EQ(arrays[idx.entry_id(nid, 0)], ptr); - *arrays[idx.entry_id(nid, 0)] = *inputs[state.info.input_map[i]]; + *arrays[idx.entry_id(nid, 0)] = *inputs[state.info.input_map[i]]; state.dynamic_entries[idx.entry_id(nid, 0)] = false; } } for (auto i : config_.data_indices) { - auto eid = idx.entry_id(idx.input_nodes()[i], 0); + auto eid = idx.entry_id(idx.input_nodes()[i], 0); arrays[eid] = inputs[state.info.input_map[i]]; } } else { for (size_t i = 0; i < num_inputs(); ++i) { - auto nid = idx.input_nodes()[i]; + auto nid = idx.input_nodes()[i]; arrays[idx.entry_id(nid, 0)] = inputs[state.info.input_map[i]]; } } @@ -740,29 +725,31 @@ OpStatePtr CachedOp::StaticForward(const Context& default_ctx, return recording ? state_ptr : OpStatePtr(); } -OpStatePtr CachedOp::DynamicForward(const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs, - bool use_naive_run) { + +OpStatePtr CachedOp::DynamicForward( + const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs, + bool use_naive_run) { using namespace nnvm; using namespace imperative; // Initialize bool recording = Imperative::Get()->is_recording(); - auto op_state = OpStatePtr::Create(); - auto& runtime = op_state.get_state(); + auto op_state = OpStatePtr::Create(); + auto& runtime = op_state.get_state(); { auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); std::lock_guard lock(state.mutex); SetForwardGraph(default_ctx, &state.info, recording, inputs); runtime.info.fwd_graph = state.info.fwd_graph; runtime.info.input_map = state.info.input_map; } - nnvm::Graph& g = runtime.info.fwd_graph; + nnvm::Graph& g = runtime.info.fwd_graph; const auto& idx = g.indexed_graph(); - auto& buff = runtime.buff; - auto& states = runtime.op_states; + auto& buff = runtime.buff; + auto& states = runtime.op_states; // Allocate entries buff.resize(idx.num_node_entries()); @@ -773,54 +760,33 @@ OpStatePtr CachedOp::DynamicForward(const Context& default_ctx, arrays.push_back(&buffered_array); } std::vector array_reqs(arrays.size(), kWriteTo); - const auto& dispatch_modes = g.GetAttr("dispatch_mode"); + const auto& dispatch_modes = g.GetAttr("dispatch_mode"); const std::string& graph_type = recording ? FULL : FORWARD; std::vector ref_count = - g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); + g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); for (size_t i = 0; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) - array_reqs[i] = kNullOp; + if (ref_count[i] == 0) array_reqs[i] = kNullOp; } CollectInputOutputNDRefs(g, inputs, runtime.info.input_map, outputs, &arrays); if (!use_naive_run) { - const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); + const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); CreateGraphNDs(g, default_ctx, mem_plan, &array_reqs, &arrays); // If CachedOp is running in the inline mode, it uses RunGraph to record // computation; otherwise, CachedOp records computation itself. // So if it's not the inline mode, we disable recording. 
- RunGraph(false, - idx, - arrays, - 0, - idx.num_nodes(), - std::move(array_reqs), - std::move(ref_count), - &states, - dispatch_modes, - recording && inlining_, - nullptr, - monitor_callback_, - monitor_all_); + RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), + std::move(ref_count), &states, dispatch_modes, + recording && inlining_, nullptr, monitor_callback_, monitor_all_, + nleafs_); } else { mxnet::ShapeVector shapes = g.GetAttr("shape"); - NaiveRunGraph(false, - default_ctx, - idx, - arrays, - 0, - idx.num_nodes(), - std::move(array_reqs), - std::move(ref_count), - &states, - dispatch_modes, - recording && inlining_, - &shapes, - monitor_callback_, - monitor_all_); + NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(), + std::move(array_reqs), std::move(ref_count), &states, + dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_); { - auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); auto copied_shape = shapes; std::lock_guard lock(state.mutex); state.info.fwd_graph.attrs["shape"] = std::make_shared(std::move(copied_shape)); @@ -830,22 +796,23 @@ OpStatePtr CachedOp::DynamicForward(const Context& default_ctx, return op_state; } -OpStatePtr CachedOp::Forward(const std::shared_ptr& op_ptr, - const std::vector& inputs, - const std::vector& outputs, - const Context& default_ctx) { +OpStatePtr CachedOp::Forward( + const std::shared_ptr& op_ptr, + const std::vector& inputs, + const std::vector& outputs, + const Context& default_ctx) { static const auto cached_op = nnvm::Op::Get("_CachedOp"); CHECK_EQ(inputs.size(), num_inputs()); // Assign the storage information for the input arguments. Similar to the // implementation in `graph_executor.cc`, we use `mutable_input_nodes()` to // distinguish between weight parameters and auxiliary states. - const auto& fwd_idx = fwd_graph_.indexed_graph(); + const auto& fwd_idx = fwd_graph_.indexed_graph(); const auto& mutable_input_nodes = fwd_idx.mutable_input_nodes(); for (size_t i = 0; i < fwd_idx.input_nodes().size(); ++i) { - const uint32_t nid = fwd_idx.input_nodes().at(i); - const nnvm::NodeAttrs& attrs = fwd_idx[nid].source->attrs; - const std::string& arg_name = attrs.name; + const uint32_t nid = fwd_idx.input_nodes().at(i); + const nnvm::NodeAttrs& attrs = fwd_idx[nid].source->attrs; + const std::string& arg_name = attrs.name; const std::string profiler_scope = common::NodeAttrsGetProfilerScope(attrs); if (mutable_input_nodes.count(nid)) { inputs[i]->AssignStorageInfo(profiler_scope + "aux_state:", arg_name); @@ -856,14 +823,16 @@ OpStatePtr CachedOp::Forward(const std::shared_ptr& op_ptr, { auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); const auto& idx = state.info.fwd_graph.indexed_graph(); for (size_t i = 0; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->ctx(), default_ctx) << "CachedOp requires all inputs to live on the same context. 
But " - << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx << " while " - << idx[idx.input_nodes()[i]].source->attrs.name << " is on " << inputs[i]->ctx(); + << idx[idx.input_nodes()[0]].source->attrs.name + << " is on " << default_ctx << " while " + << idx[idx.input_nodes()[i]].source->attrs.name + << " is on " << inputs[i]->ctx(); } } @@ -872,9 +841,9 @@ OpStatePtr CachedOp::Forward(const std::shared_ptr& op_ptr, OpStatePtr op_state; try { if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) { - config_.is_dynamic = true; + config_.is_dynamic = true; config_.static_alloc = false; - op_state = DynamicForward(default_ctx, inputs, outputs, true); + op_state = DynamicForward(default_ctx, inputs, outputs, true); } else if (config_.static_alloc) { op_state = StaticForward(default_ctx, inputs, outputs); } else { @@ -889,43 +858,45 @@ OpStatePtr CachedOp::Forward(const std::shared_ptr& op_ptr, if (Imperative::Get()->is_recording() && !inlining_) { nnvm::NodeAttrs attrs; - attrs.op = cached_op; - attrs.name = "_cachedop"; + attrs.op = cached_op; + attrs.name = "_cachedop"; attrs.parsed = op_ptr; Imperative::Get()->RecordOp( - std::move(attrs), inputs, outputs, op_state, &save_inputs(), &save_outputs()); + std::move(attrs), inputs, outputs, op_state, + &save_inputs(), &save_outputs()); } return op_state; } -void CachedOp::DynamicBackward(const bool retain_graph, - const OpStatePtr& op_state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { +void CachedOp::DynamicBackward( + const bool retain_graph, + const OpStatePtr& op_state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { using namespace nnvm; using namespace imperative; // Initialize Context default_ctx = outputs[0]->ctx(); - auto& runtime = op_state.get_state(); + auto& runtime = op_state.get_state(); { auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); std::lock_guard lock(state.mutex); state.info.fwd_graph = runtime.info.fwd_graph; state.info.input_map = runtime.info.input_map; SetBackwardGraph(&state.info, reqs, inputs); - runtime.info.full_graph = state.info.full_graph; + runtime.info.full_graph = state.info.full_graph; runtime.info.bwd_input_eid = state.info.bwd_input_eid; } - nnvm::Graph& g = runtime.info.full_graph; + nnvm::Graph& g = runtime.info.full_graph; const auto& idx = g.indexed_graph(); - auto& buff = runtime.buff; - auto& states = runtime.op_states; + auto& buff = runtime.buff; + auto& states = runtime.op_states; size_t num_forward_outputs = runtime.info.fwd_graph.outputs.size(); - size_t num_forward_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes(); + size_t num_forward_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes(); size_t num_forward_entries = runtime.info.fwd_graph.indexed_graph().num_node_entries(); buff.resize(idx.num_node_entries()); std::vector arrays; @@ -940,8 +911,7 @@ void CachedOp::DynamicBackward(const bool retain_graph, arrays[runtime.info.bwd_input_eid[i]] = inputs[BwdOriginalInput(runtime.info.input_map, i)]; } for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) { - if (reqs[i] == kNullOp) - continue; + if (reqs[i] == kNullOp) continue; const auto eid = idx.entry_id(idx.outputs()[j++]); // An input and an output may share the same array. 
INIT_DETACHED(outputs[i], arrays[eid]); @@ -951,47 +921,29 @@ void CachedOp::DynamicBackward(const bool retain_graph, // Allocate NDArrays auto ref_count = g.GetAttr >(AddPrefix(BACKWARD, REF_COUNT)); if (retain_graph) { - for (size_t i = 0; i < num_forward_entries; ++i) - ++ref_count[i]; + for (size_t i = 0; i < num_forward_entries; ++i) ++ref_count[i]; } std::vector array_reqs(arrays.size(), kWriteTo); // set output reqs for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) { - if (reqs[i] == kNullOp) - continue; + if (reqs[i] == kNullOp) continue; array_reqs[idx.entry_id(idx.outputs()[j++])] = reqs[i]; } // set null reqs based on ref counts for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) - array_reqs[i] = kNullOp; + if (ref_count[i] == 0) array_reqs[i] = kNullOp; } - const auto& mem_plan = g.GetAttr(AddPrefix(BACKWARD, MEM_PLAN)); - AllocateMemory(g, - idx, - default_ctx, - num_forward_entries, - idx.num_node_entries(), - mem_plan, - arrays, - &array_reqs); + const auto& mem_plan = g.GetAttr(AddPrefix(BACKWARD, MEM_PLAN)); + AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(), + mem_plan, arrays, &array_reqs); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - RunGraph(retain_graph, - idx, - arrays, - num_forward_nodes, - idx.num_nodes(), - std::move(array_reqs), - std::move(ref_count), - &states, - dispatch_modes, - Imperative::Get()->is_recording(), - nullptr, - monitor_callback_); + RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), + std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, + Imperative::Get()->is_recording(), nullptr, monitor_callback_); if (retain_graph) { buff.resize(num_forward_entries); @@ -1001,11 +953,12 @@ void CachedOp::DynamicBackward(const bool retain_graph, } } -void CachedOp::StaticBackward(const bool retain_graph, - const OpStatePtr& state_ptr, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { +void CachedOp::StaticBackward( + const bool retain_graph, + const OpStatePtr& state_ptr, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { using namespace nnvm; using namespace imperative; @@ -1016,8 +969,8 @@ void CachedOp::StaticBackward(const bool retain_graph, bool match = SetBackwardGraph(&state.info, reqs, inputs, true); - nnvm::Graph& g = state.info.full_graph; - const auto& idx = g.indexed_graph(); + nnvm::Graph& g = state.info.full_graph; + const auto& idx = g.indexed_graph(); auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes(); if (!state.bwd_alloc || !match) { @@ -1028,41 +981,38 @@ void CachedOp::StaticBackward(const bool retain_graph, // The input and output arrays should only be valid for this run, // so we shouldn't modify the state's array list. 
state.arrays_with_in_out = state.arrays; - auto& arrays = state.arrays_with_in_out; + auto& arrays = state.arrays_with_in_out; for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) { auto eid = state.info.bwd_input_eid[i]; - if (eid == kEidNotExist || !state.dynamic_entries[eid]) - continue; + if (eid == kEidNotExist || !state.dynamic_entries[eid]) continue; arrays[eid] = inputs[BwdOriginalInput(state.info.input_map, i)]; } if (config_.static_shape) { for (auto i : config_.param_indices) { const auto iter = state.info.fwd_input_to_grad_output.find(i); - if (iter == state.info.fwd_input_to_grad_output.end()) - continue; + if (iter == state.info.fwd_input_to_grad_output.end()) continue; auto entry = state.info.grad_graph.outputs[iter->second]; - if (!idx.exist(entry.node.get())) - continue; + if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); - if ((!arrays[eid]->IsSame(*outputs[iter->second]) && state.array_reqs[eid] != kNullOp) || + if ((!arrays[eid]->IsSame(*outputs[iter->second]) && + state.array_reqs[eid] != kNullOp) || !(state.array_reqs[eid] == reqs[iter->second])) { - match = false; + match = false; state.array_reqs[eid] = reqs[iter->second]; // An input and an output may share the same array. INIT_DETACHED(outputs[iter->second], arrays[eid]); - *arrays[eid] = *outputs[iter->second]; + *arrays[eid] = *outputs[iter->second]; state.dynamic_entries[eid] = false; } } for (auto i : config_.data_indices) { const auto iter = state.info.fwd_input_to_grad_output.find(i); - if (iter == state.info.fwd_input_to_grad_output.end()) - continue; + if (iter == state.info.fwd_input_to_grad_output.end()) continue; auto entry = state.info.grad_graph.outputs[iter->second]; - if (!idx.exist(entry.node.get())) - continue; + if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); + state.array_reqs[eid] = reqs[iter->second]; // An input and an output may share the same array. INIT_DETACHED(outputs[iter->second], arrays[eid]); arrays[eid] = outputs[iter->second]; @@ -1070,9 +1020,9 @@ void CachedOp::StaticBackward(const bool retain_graph, } else { for (size_t i = 0; i < state.info.grad_graph.outputs.size(); ++i) { auto entry = state.info.grad_graph.outputs[i]; - if (!idx.exist(entry.node.get())) - continue; + if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); + state.array_reqs[eid] = reqs[i]; // An input and an output may share the same array. 
INIT_DETACHED(outputs[i], arrays[eid]); arrays[eid] = outputs[i]; @@ -1086,16 +1036,17 @@ void CachedOp::StaticBackward(const bool retain_graph, StaticRunOps(default_ctx, g, state_ptr, arrays, num_forward_nodes, idx.num_nodes()); } -void CachedOp::Backward(const bool retain_graph, - const OpStatePtr& state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { - const auto& fwd_idx = fwd_graph_.indexed_graph(); - const auto& full_idx = full_graph_.indexed_graph(); +void CachedOp::Backward( + const bool retain_graph, + const OpStatePtr& state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { + const auto& fwd_idx = fwd_graph_.indexed_graph(); + const auto& full_idx = full_graph_.indexed_graph(); const auto& mutable_input_nodes = fwd_idx.mutable_input_nodes(); for (size_t i = 0, j = 0; i < fwd_idx.input_nodes().size(); ++i) { - const uint32_t nid = fwd_idx.input_nodes().at(i); + const uint32_t nid = fwd_idx.input_nodes().at(i); const std::string& arg_name = fwd_idx[nid].source->attrs.name; const std::string profiler_scope = common::NodeAttrsGetProfilerScope(fwd_idx[nid].source->attrs); @@ -1104,9 +1055,10 @@ void CachedOp::Backward(const bool retain_graph, } outputs[j++]->AssignStorageInfo(profiler_scope + "arg_grad:", arg_name); } - for (size_t i = fwd_idx.input_nodes().size(), j = 0; i < full_idx.input_nodes().size(); ++i) { - const nnvm::NodeAttrs& attrs = full_idx[full_idx.input_nodes().at(i)].source->attrs; - const std::string& entry_name = attrs.name; + for (size_t i = fwd_idx.input_nodes().size(), j = 0; + i < full_idx.input_nodes().size(); ++i) { + const nnvm::NodeAttrs& attrs = full_idx[full_idx.input_nodes().at(i)].source->attrs; + const std::string& entry_name = attrs.name; const std::string profiler_scope = common::NodeAttrsGetProfilerScope(attrs); inputs[j++]->AssignStorageInfo(profiler_scope, entry_name); } @@ -1160,11 +1112,11 @@ void CachedOpForward(const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CachedOpActualState& s = state_ptr.get_state(); - std::vector in_bufs = inputs; + CachedOpActualState &s = state_ptr.get_state(); + std::vector in_bufs = inputs; std::vector out_bufs = outputs; - std::vector in_ptrs(in_bufs.size()); - std::vector out_ptrs(out_bufs.size()); + std::vector in_ptrs(in_bufs.size()); + std::vector out_ptrs(out_bufs.size()); for (size_t i = 0; i < in_ptrs.size(); i++) in_ptrs[i] = &in_bufs[i]; for (size_t i = 0; i < out_ptrs.size(); i++) @@ -1184,7 +1136,7 @@ void CachedOpForward(const OpStatePtr& state_ptr, orig_is_train = Imperative::Get()->is_training(); CHECK(inputs.size() > 0) << "cached op forward requires at least 1 input"; Context default_ctx = inputs[0].ctx(); - s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs, default_ctx); + s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs, default_ctx); Imperative::Get()->set_is_training(orig_is_train); Imperative::Get()->set_is_recording(orig_is_record); // The arrays in out_ptrs may be changed by CachedOp. 
@@ -1205,29 +1157,29 @@ void CachedOpBackward(const OpStatePtr& state_ptr, const std::vector& outputs) { using namespace nnvm; using namespace imperative; - CachedOpActualState& s = state_ptr.get_state(); - std::vector in_bufs = inputs; + CachedOpActualState &s = state_ptr.get_state(); + std::vector in_bufs = inputs; std::vector out_bufs = outputs; - std::vector in_ptrs; - std::vector out_ptrs; + std::vector in_ptrs; + std::vector out_ptrs; CHECK_EQ(s.op->num_backward_inputs(), inputs.size()); in_ptrs.reserve(s.op->num_backward_inputs()); out_ptrs.reserve(s.op->num_inputs()); - const std::vector& save_inputs = s.op->save_inputs(); - const std::vector& save_outputs = s.op->save_outputs(); - size_t bwd_in_dep = s.op->num_inputs(); - size_t bwd_out_dep = s.op->num_outputs(); + const std::vector &save_inputs = s.op->save_inputs(); + const std::vector &save_outputs = s.op->save_outputs(); + size_t bwd_in_dep = s.op->num_inputs(); + size_t bwd_out_dep = s.op->num_outputs(); CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep); size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep; // Find inputs, outputs and ograds auto ograds_begin = in_bufs.begin(); - auto ograds_end = in_bufs.begin() + bwd_ograd_dep; - auto in_begin = ograds_end; - auto in_end = in_begin + bwd_in_dep; - auto out_begin = in_end; - auto out_end = in_bufs.end(); + auto ograds_end = in_bufs.begin() + bwd_ograd_dep; + auto in_begin = ograds_end; + auto in_end = in_begin + bwd_in_dep; + auto out_begin = in_end; + auto out_end = in_bufs.end(); for (auto it = ograds_begin; it != ograds_end; it++) in_ptrs.push_back(&(*it)); @@ -1280,10 +1232,11 @@ void CachedOpBackward(const OpStatePtr& state_ptr, /* * Register the callback to be called when the operator is executed */ -void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, bool monitor_all) { - CHECK(callback) << "invalid callback"; - monitor_callback_ = callback; - monitor_all_ = monitor_all; +void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, + bool monitor_all) { + CHECK(callback) << "invalid callback"; + monitor_callback_ = callback; + monitor_all_ = monitor_all; } OpStatePtr CreateCachedOpState(const NodeAttrs& attrs, @@ -1297,19 +1250,19 @@ OpStatePtr CreateCachedOpState(const NodeAttrs& attrs, bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector* in_attrs, - std::vector* out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { using namespace imperative; nnvm::Graph g(full_graph_); - const auto& idx = g.indexed_graph(); - const auto& outputs = idx.outputs(); + const auto& idx = g.indexed_graph(); + const auto &outputs = idx.outputs(); const size_t num_forward_outputs = fwd_graph_.outputs.size(); CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size()); // Construct bwd_input_eid std::vector bwd_input_eid; - SetBackwardInputEid( - bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, ograd_entries_, idx, &bwd_input_eid); + SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, + ograd_entries_, idx, &bwd_input_eid); CHECK_EQ(in_attrs->size(), bwd_input_eid.size()); // Prepare stypes and contexts based on inputs @@ -1375,98 +1328,92 @@ size_t CachedOp::BwdOriginalInput(const std::vector& input_map, size_t n } NNVM_REGISTER_OP(_CachedOp) - .set_num_inputs([](const NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_inputs(); - }) - .set_num_outputs([](const NodeAttrs& attrs) { 
- const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_outputs(); - }) - .set_attr_parser(CachedOpParamParser) - .set_attr("FGradient", - [](const nnvm::ObjectPtr& n, - const std::vector& ograds) { - const CachedOpPtr& op = nnvm::get(n->attrs.parsed); - return op->Gradient(n, ograds); - }) - .set_attr("FListInputNames", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ListForwardInputNames(); - }) - .set_attr("FListOutputNames", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = - nnvm::get(attrs.parsed); - return op->ListForwardOutputNames(); - }) - .set_attr("FCreateOpState", CreateCachedOpState) - .set_attr("FInferShape", - [](const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector* in_shapes, - mxnet::ShapeVector* out_shapes) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpShapeHelper( - op->GetForwardSym(), in_shapes, out_shapes); - }) - .set_attr( - "FInferType", - [](const nnvm::NodeAttrs& attrs, std::vector* in_types, std::vector* out_types) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types); - }) - .set_attr( - "FInferStorageType", - [](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_stypes, - std::vector* out_stypes) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpStorageTypeHelper( - op->GetForwardSym(), dev_mask, dispatch_mode, in_stypes, out_stypes); - }) - .set_attr("FStatefulComputeEx", CachedOpForward) - .set_attr("FStatefulComputeEx", CachedOpForward) - .set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpMutableInputsHelper( - op->GetForwardSym()); - }) - .set_attr("FResourceRequest", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpResourceRequestHelper( - op->GetForwardSym()); - }) - .set_attr("FExecType", op::DefaultSubgraphOpExecType) - .add_argument("data", "NDArray-or-Symbol[]", "input data list"); +.set_num_inputs([](const NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_inputs(); + }) +.set_num_outputs([](const NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_outputs(); + }) +.set_attr_parser(CachedOpParamParser) +.set_attr("FGradient", + [](const nnvm::ObjectPtr& n, const std::vector& ograds) { + const CachedOpPtr& op = nnvm::get(n->attrs.parsed); + return op->Gradient(n, ograds); + }) +.set_attr("FListInputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ListForwardInputNames(); + }) +.set_attr("FListOutputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ListForwardOutputNames(); + }) +.set_attr("FCreateOpState", CreateCachedOpState) +.set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shapes, + mxnet::ShapeVector *out_shapes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes); + }) +.set_attr("FInferType", + [](const nnvm::NodeAttrs& attrs, + std::vector *in_types, + std::vector *out_types) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, 
out_types); + }) +.set_attr("FInferStorageType", + [](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(), + dev_mask, dispatch_mode, + in_stypes, out_stypes); + }) +.set_attr("FStatefulComputeEx", CachedOpForward) +.set_attr("FStatefulComputeEx", CachedOpForward) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym()); + }) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym()); + }) +.set_attr("FExecType", op::DefaultSubgraphOpExecType) +.add_argument("data", "NDArray-or-Symbol[]", "input data list"); NNVM_REGISTER_OP(_backward_CachedOp) - .set_num_inputs([](const NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_backward_inputs(); - }) - .set_num_outputs([](const NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_inputs() - op->mutable_input_nodes().size(); - }) - .set_attr("FInferStorageType", - [](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_attrs, - std::vector* out_attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->BackwardStorageType( - attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); - }) - .set_attr("FStatefulComputeEx", CachedOpBackward) - .set_attr("FStatefulComputeEx", CachedOpBackward) - .set_attr("FExecType", op::DefaultSubgraphOpExecType) - .set_attr("TIsLayerOpBackward", true) - .set_attr("TIsBackward", true); +.set_num_inputs([](const NodeAttrs& attrs){ + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_backward_inputs(); + }) +.set_num_outputs([](const NodeAttrs& attrs){ + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_inputs() - op->mutable_input_nodes().size(); + }) +.set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); + }) +.set_attr("FStatefulComputeEx", CachedOpBackward) +.set_attr("FStatefulComputeEx", CachedOpBackward) +.set_attr("FExecType", op::DefaultSubgraphOpExecType) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true); } // namespace mxnet diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 079a56e20a12..bb0f8ffad2c2 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -37,41 +37,41 @@ namespace mxnet { namespace { -static const char FULL[] = "full"; -static const char FORWARD[] = "forward"; -static const char BACKWARD[] = "backward"; -static const char REF_COUNT[] = "ref_count"; -static const char MEM_PLAN[] = "mem_plan"; -static const char STORAGE_PLAN[] = "storage_plan"; - -std::string AddPrefix(const std::string& prefix, const std::string& s) { + static const char FULL[] = "full"; + static const char FORWARD[] = "forward"; + static const char BACKWARD[] = "backward"; + static const char REF_COUNT[] = "ref_count"; + static const char MEM_PLAN[] = "mem_plan"; + static const char 
STORAGE_PLAN[] = "storage_plan"; + +std::string AddPrefix(const std::string& prefix, + const std::string& s) { return prefix + "_" + s; } nnvm::NodeEntry AggregateGradient(std::vector&& v) { using nnvm::Op; - static size_t inplace_sum_cap = dmlc::GetEnv("MXNET_EXEC_INPLACE_GRAD_SUM_CAP", 8); + static size_t inplace_sum_cap = dmlc::GetEnv("MXNET_EXEC_INPLACE_GRAD_SUM_CAP", 8); static const Op* ewise_plus_op = Op::Get("_grad_add"); - static const Op* ewise_sum_op = Op::Get("ElementWiseSum"); - static const Op* identity_op = Op::Get("identity"); - static const Op* zeros_op = Op::Get("_zeros"); + static const Op* ewise_sum_op = Op::Get("ElementWiseSum"); + static const Op* identity_op = Op::Get("identity"); + static const Op* zeros_op = Op::Get("_zeros"); static const Op* zeros_like_op = Op::Get("zeros_like"); if (v.empty()) { nnvm::ObjectPtr ng = nnvm::Node::Create(); - ng->attrs.op = Op::Get("_zeros_without_dtype"); - ng->attrs.name = "zeros_without_dtype"; + ng->attrs.op = Op::Get("_zeros_without_dtype"); + ng->attrs.name = "zeros_without_dtype"; ng->attrs.op->attr_parser(&(ng->attrs)); return nnvm::NodeEntry(std::move(ng), 0, 0); } // remove zero in the sum. at least keep 1. auto begin = std::remove_if(v.begin(), v.end(), [](const nnvm::NodeEntry& nodeEntry) { - CHECK(nodeEntry.node); - return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op; + CHECK(nodeEntry.node); + return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op; }); - if (begin == v.begin()) - ++begin; + if (begin == v.begin()) ++begin; v.erase(begin, v.end()); CHECK(!v.empty()); @@ -79,9 +79,9 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { return std::move(v[0]); } else { if (v.size() < inplace_sum_cap) { - nnvm::ObjectPtr sum_node = nnvm::Node::Create(); - sum_node->attrs.op = ewise_sum_op; - sum_node->attrs.name = "sum_grad"; + nnvm::ObjectPtr sum_node = nnvm::Node::Create(); + sum_node->attrs.op = ewise_sum_op; + sum_node->attrs.name = "sum_grad"; sum_node->attrs.dict["num_args"] = std::to_string(v.size()); sum_node->attrs.op->attr_parser(&(sum_node->attrs)); sum_node->inputs = std::move(v); @@ -105,24 +105,24 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { // the node entries v passed in here are of the same node of // op _identity_with_attr_like_rhs. We should skip adding a node // to its own control_deps. - if (v[i - 1].node != v[i].node) { + if (v[i-1].node != v[i].node) { v[i].node->control_deps.push_back(ret.node); } std::ostringstream os; os << "sum_grad_" << i; nnvm::ObjectPtr x = nnvm::Node::Create(); - x->attrs.op = ewise_plus_op; - x->attrs.name = os.str(); - x->inputs = {ret, v[i]}; - ret = nnvm::NodeEntry(std::move(x), 0, 0); + x->attrs.op = ewise_plus_op; + x->attrs.name = os.str(); + x->inputs = {ret, v[i]}; + ret = nnvm::NodeEntry(std::move(x), 0, 0); } // identity node is used to avoid exposure of dummy plus node // when its output get assigned to another space. 
nnvm::ObjectPtr id_node = nnvm::Node::Create(); - id_node->attrs.op = identity_op; - id_node->attrs.name = "sum_grad_final"; - id_node->inputs = {ret}; + id_node->attrs.op = identity_op; + id_node->attrs.name = "sum_grad_final"; + id_node->inputs = {ret}; return nnvm::NodeEntry{id_node, 0, 0}; } } @@ -142,7 +142,7 @@ void CollectInputOutputNDRefs(const nnvm::Graph& g, const std::vector& input_map, const std::vector& outputs, std::vector* arrays) { - const auto& idx = g.indexed_graph(); + const auto& idx = g.indexed_graph(); size_t num_inputs = idx.input_nodes().size(); for (size_t i = 0; i < num_inputs; ++i) { (*arrays)[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[input_map[i]]; @@ -168,24 +168,27 @@ void CreateGraphNDs(const nnvm::Graph& g, std::vector* array_reqs, std::vector* arrays) { const auto& idx = g.indexed_graph(); - mxnet::imperative::AllocateMemory( - g, idx, default_ctx, 0, idx.num_node_entries(), mem_plan, *arrays, array_reqs); - const auto& dtypes = g.GetAttr("dtype"); - const auto& shapes = g.GetAttr("shape"); - const auto& stypes = g.GetAttr("storage_type"); + mxnet::imperative::AllocateMemory(g, idx, default_ctx, 0, + idx.num_node_entries(), mem_plan, *arrays, + array_reqs); + const auto &dtypes = g.GetAttr("dtype"); + const auto &shapes = g.GetAttr("shape"); + const auto &stypes = g.GetAttr("storage_type"); for (size_t i = 0; i < idx.outputs().size(); ++i) { auto eid = idx.entry_id(idx.outputs()[i]); if (!(*arrays)[eid]->is_none()) continue; - *((*arrays)[eid]) = NDArray( - static_cast(stypes[eid]), shapes[eid], default_ctx, true, dtypes[eid]); + *((*arrays)[eid]) = NDArray(static_cast(stypes[eid]), + shapes[eid], default_ctx, true, dtypes[eid]); const nnvm::NodeAttrs& attrs = idx[idx.outputs()[i].node_id].source->attrs; - (*arrays)[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); + (*arrays)[eid]->AssignStorageInfo( + common::NodeAttrsGetProfilerScope(attrs), + attrs.name); } } /* \brief create a forward graph from they Symbol */ -void CreateForwardGraph(const nnvm::Symbol& sym, nnvm::Graph* fwd_graph) { +void CreateForwardGraph(const nnvm::Symbol &sym, nnvm::Graph *fwd_graph) { using namespace nnvm; static const auto _copy_op = Op::Get("_copy"); NodeEntryMap dedup_out; @@ -193,12 +196,12 @@ void CreateForwardGraph(const nnvm::Symbol& sym, nnvm::Graph* fwd_graph) { // to graph outputs. 
Since node entry stores information about the node // as well as the input node of the graph, a graph can be recreated from a // symbol by just copying the outputs - for (const NodeEntry& nodeEntry : sym.outputs) { + for (const NodeEntry &nodeEntry : sym.outputs) { if (dedup_out.find(nodeEntry) != dedup_out.end()) { ObjectPtr copy_node = Node::Create(); copy_node->attrs.op = _copy_op; - copy_node->attrs.name = - nodeEntry.node->attrs.name + "_copy" + std::to_string(dedup_out[nodeEntry]++); + copy_node->attrs.name = nodeEntry.node->attrs.name + "_copy" + + std::to_string(dedup_out[nodeEntry]++); copy_node->inputs.emplace_back(nodeEntry); if (_copy_op->attr_parser != nullptr) { _copy_op->attr_parser(&(copy_node->attrs)); @@ -222,15 +225,15 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph, static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; ograd_entries->reserve(fwd_graph->outputs.size()); for (size_t i = 0; i < fwd_graph->outputs.size(); ++i) { - nnvm::ObjectPtr np = Node::Create(); - const nnvm::NodeAttrs& attrs = fwd_graph->outputs[i].node->attrs; - np->attrs.name = attrs.name + "_head_grad"; + nnvm::ObjectPtr np = Node::Create(); + const nnvm::NodeAttrs& attrs = fwd_graph->outputs[i].node->attrs; + np->attrs.name = attrs.name + "_head_grad"; np->attrs.dict["__profiler_scope__"] = common::NodeAttrsGetProfilerScope(attrs); ograd_entries->emplace_back(np); } std::vector xs; - const IndexedGraph& indexed_graph = fwd_graph->indexed_graph(); + const IndexedGraph &indexed_graph = fwd_graph->indexed_graph(); // Create vector of inputs to be passed to the gradient pass for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) { const uint32_t node_id = indexed_graph.input_nodes()[i]; @@ -248,15 +251,11 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph, // There are inputs in computation graph that require gradients if (!xs.empty()) { try { - *grad_graph = pass::MXGradient(*fwd_graph, - fwd_graph->outputs, - xs, - *ograd_entries, - mxnet::AggregateGradient, - nullptr, - zero_ops, - "_copy"); - } catch (const nnvm::pass::InvalidGraphError& e) { + *grad_graph = pass::MXGradient( + *fwd_graph, fwd_graph->outputs, xs, *ograd_entries, + mxnet::AggregateGradient, nullptr, + zero_ops, "_copy"); + } catch (const nnvm::pass::InvalidGraphError &e) { *grad_graph = nnvm::Graph(); } } else { @@ -279,27 +278,25 @@ void CreateFullGraph(const nnvm::Symbol& sym, *fwd_graph = exec::EliminateCommonExpr(std::move(*fwd_graph)); // construct backward graph - CreateBackwardGraph(fwd_graph, grad_graph, ograd_entries, fwd_input_to_grad_output); + CreateBackwardGraph(fwd_graph, grad_graph, ograd_entries, + fwd_input_to_grad_output); full_graph->outputs = fwd_graph->outputs; // add backward graph outputs to full graph - for (const auto& i : grad_graph->outputs) { + for (const auto &i : grad_graph->outputs) { full_graph->outputs.emplace_back(i); } } /* \brief Set Ref counts for node entries for forward graph */ -void SetForwardRefCounts(nnvm::Graph* fwd_graph) { +void SetForwardRefCounts(nnvm::Graph *fwd_graph) { const auto& idx = fwd_graph->indexed_graph(); std::vector ref_count(idx.num_node_entries(), 0); - for (const auto& i : idx.input_nodes()) - ++ref_count[idx.entry_id(i, 0)]; - for (const auto& i : idx.outputs()) - ++ref_count[idx.entry_id(i)]; + for (const auto& i : idx.input_nodes()) ++ref_count[idx.entry_id(i, 0)]; + for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; for (size_t i = 0; i < idx.num_nodes(); ++i) { - for (const auto& j : idx[i].inputs) - 
++ref_count[idx.entry_id(j)]; + for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; } fwd_graph->attrs[AddPrefix(FORWARD, REF_COUNT)] = @@ -311,7 +308,7 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) { const auto& idx = fwd_graph->indexed_graph(); SetForwardRefCounts(fwd_graph); - size_t num_forward_nodes = idx.num_nodes(); + size_t num_forward_nodes = idx.num_nodes(); size_t num_forward_entries = idx.num_node_entries(); const auto& full_idx = full_graph.indexed_graph(); @@ -319,39 +316,38 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) { std::vector temp_ref_count(full_idx.num_node_entries(), 0); for (size_t i = num_forward_nodes; i < full_idx.num_nodes(); ++i) { for (const auto& j : full_idx[i].inputs) { - ++temp_ref_count[full_idx.entry_id(j)]; + ++temp_ref_count[full_idx.entry_id(j)]; } } - auto full_ref_count = fwd_graph->GetAttr>(AddPrefix(FORWARD, REF_COUNT)); - for (size_t i = 0; i < num_forward_entries; ++i) - full_ref_count.at(i) += temp_ref_count[i]; + auto full_ref_count = fwd_graph->GetAttr >(AddPrefix(FORWARD, + REF_COUNT)); + for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += temp_ref_count[i]; fwd_graph->attrs[AddPrefix(FULL, REF_COUNT)] = std::make_shared(std::move(full_ref_count)); } -void OptimizeGraph(nnvm::Graph* full_graph, - nnvm::Graph* fwd_graph, - nnvm::Graph* grad_graph, - std::vector* input_map, - const Context& context, - size_t num_forward_outputs, - const bool inlining) { +void OptimizeGraph(nnvm::Graph* full_graph, nnvm::Graph* fwd_graph, nnvm::Graph* grad_graph, + std::vector* input_map, const Context& context, + size_t num_forward_outputs, const bool inlining) { input_map->resize(full_graph->indexed_graph().input_nodes().size()); std::iota(input_map->begin(), input_map->end(), 0); #if MXNET_USE_CUDA && !defined(_WIN32) - if (context.dev_mask() == kGPU && !inlining && dmlc::GetEnv("MXNET_USE_FUSION", true)) { + if (context.dev_mask() == kGPU && + !inlining && + dmlc::GetEnv("MXNET_USE_FUSION", true)) { nnvm::Graph unoptimized_graph; common::CopyGraph(&unoptimized_graph, *full_graph, false); if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) { *full_graph = exec::FusePointwise(*full_graph, num_forward_outputs); // Fill in input_map - mapping from the new to the original input indices. - const auto& original_inputs = unoptimized_graph.indexed_graph().input_nodes(); - const auto& new_inputs = full_graph->indexed_graph().input_nodes(); + const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes(); + const auto &new_inputs = full_graph->indexed_graph().input_nodes(); if (original_inputs.size() != new_inputs.size()) { - LOG(WARNING) << "Number of inputs after fusion does not match original number of inputs. " - << "This is most probably a bug. Disabling fusion for this run."; + LOG(WARNING) + << "Number of inputs after fusion does not match original number of inputs. " + << "This is most probably a bug. 
Disabling fusion for this run."; *full_graph = unoptimized_graph; } else { std::unordered_map original_input_map; @@ -369,22 +365,25 @@ void OptimizeGraph(nnvm::Graph* full_graph, } } else { LOG(WARNING) - << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!"; - } + << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!"; + } } #else // Only warn user if MXNET_USE_FUSION env var is explicitly set - if (context.dev_mask() == kGPU && !inlining && dmlc::GetEnv("MXNET_USE_FUSION", false)) { + if (context.dev_mask() == kGPU && !inlining && + dmlc::GetEnv("MXNET_USE_FUSION", false)) { exec::WarnFusionNotSupported(); } #endif // MXNET_USE_CUDA && !defined(_WIN32) - *fwd_graph = nnvm::Graph(); - fwd_graph->outputs = std::vector( - full_graph->outputs.begin(), full_graph->outputs.begin() + num_forward_outputs); - *grad_graph = nnvm::Graph(); - grad_graph->outputs = std::vector( - full_graph->outputs.begin() + num_forward_outputs, full_graph->outputs.end()); + *fwd_graph = nnvm::Graph(); + fwd_graph->outputs = std::vector(full_graph->outputs.begin(), + full_graph->outputs.begin() + + num_forward_outputs); + *grad_graph = nnvm::Graph(); + grad_graph->outputs = std::vector(full_graph->outputs.begin() + + num_forward_outputs, + full_graph->outputs.end()); SetRefCounts(fwd_graph, *full_graph); } @@ -463,10 +462,13 @@ class LazyTransformDataset; } class CachedOp { - using CachedOpMonCallback = std::function; + using CachedOpMonCallback = + std::function; public: - CachedOp(const nnvm::Symbol& sym, const std::vector>& flags); + CachedOp( + const nnvm::Symbol& sym, + const std::vector >& flags); virtual ~CachedOp(); nnvm::Symbol GetOptimizedSymbol() const; uint32_t num_inputs() const { @@ -479,7 +481,7 @@ class CachedOp { return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size(); } uint32_t num_backward_outputs() const { - auto& idx = fwd_graph_.indexed_graph(); + auto &idx = fwd_graph_.indexed_graph(); return idx.input_nodes().size() - idx.mutable_input_nodes().size(); } std::vector& save_inputs() { @@ -491,23 +493,30 @@ class CachedOp { const std::unordered_set& mutable_input_nodes() const { return fwd_graph_.indexed_graph().mutable_input_nodes(); } - virtual std::vector Gradient(const nnvm::ObjectPtr& node, - const std::vector& ograds) const; - virtual OpStatePtr Forward(const std::shared_ptr& op_ptr, - const std::vector& inputs, - const std::vector& outputs, - const Context& default_context); - virtual void Backward(const bool retain_graph, - const OpStatePtr& state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); + void set_nleafs(const std::vector& nleafs) { + nleafs_ = nleafs; + } + virtual std::vector Gradient( + const nnvm::ObjectPtr& node, + const std::vector& ograds) const; + virtual OpStatePtr Forward( + const std::shared_ptr& op_ptr, + const std::vector& inputs, + const std::vector& outputs, + const Context &default_context); + virtual void Backward( + const bool retain_graph, + const OpStatePtr& state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs); // backward storage type inference - virtual bool BackwardStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_attrs, - std::vector* out_attrs); + virtual bool BackwardStorageType( + const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs); std::vector 
ListForwardInputNames() const { nnvm::Symbol sym = GetForwardSym(); return sym.ListInputNames(nnvm::Symbol::kAll); @@ -521,7 +530,8 @@ class CachedOp { sym.outputs = fwd_graph_.outputs; return sym; } - void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, bool monitor_all = false); + void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, + bool monitor_all = false); protected: struct GraphInfo { @@ -536,32 +546,23 @@ class CachedOp { }; struct CachedOpState { - CachedOpState(const Context& context_, - const nnvm::Graph& fwd_graph_, - const nnvm::Graph& full_graph_, - const bool inlining_) { + CachedOpState(const Context &context_, const nnvm::Graph &fwd_graph_, + const nnvm::Graph &full_graph_, const bool inlining_) { context = context_; nnvm::Symbol sym; sym.outputs = fwd_graph_.outputs; - CreateFullGraph(sym.Copy(), - &info.fwd_graph, - &info.grad_graph, - &info.full_graph, - &info.ograd_entries, + CreateFullGraph(sym.Copy(), &info.fwd_graph, &info.grad_graph, + &info.full_graph, &info.ograd_entries, &info.fwd_input_to_grad_output); - OptimizeGraph(&info.full_graph, - &info.fwd_graph, - &info.grad_graph, - &info.input_map, - context_, - fwd_graph_.outputs.size(), - inlining_); - - size_t max_nodes = info.full_graph.indexed_graph().num_nodes(); - size_t max_entries = info.full_graph.indexed_graph().num_node_entries(); - info.fwd_graph.attrs["context"] = std::make_shared( - std::vector(info.fwd_graph.indexed_graph().num_nodes(), context)); + OptimizeGraph(&info.full_graph, &info.fwd_graph, &info.grad_graph, &info.input_map, + context_, fwd_graph_.outputs.size(), inlining_); + + size_t max_nodes = info.full_graph.indexed_graph().num_nodes(); + size_t max_entries = info.full_graph.indexed_graph().num_node_entries(); + info.fwd_graph.attrs["context"] = + std::make_shared(std::vector( + info.fwd_graph.indexed_graph().num_nodes(), context)); info.full_graph.attrs["context"] = std::make_shared(std::vector(max_nodes, context)); @@ -578,15 +579,15 @@ class CachedOp { Context context; GraphInfo info; - bool recording = false; - bool fwd_alloc = false; - bool bwd_alloc = false; + bool recording = false; + bool fwd_alloc = false; + bool bwd_alloc = false; bool fwd_exec_init = false; bool bwd_exec_init = false; std::vector buff; - std::vector arrays; - std::vector arrays_with_in_out; + std::vector arrays; + std::vector arrays_with_in_out; std::vector array_reqs; std::vector op_states; @@ -599,45 +600,59 @@ class CachedOp { }; OpStatePtr GetCachedOpState(const Context& ctx); - bool SetForwardGraph(const Context& default_ctx, - GraphInfo* info, - const bool recording, - const std::vector& inputs); - bool SetBackwardGraph(GraphInfo* info, - const std::vector& reqs, - const std::vector& inputs, - bool detect_inplace_addto = false); - bool CheckDynamicShapeExists(const Context& default_ctx, - const std::vector& inputs, - bool erase_result); - void StaticAllocMemory(const OpStatePtr& state_ptr, bool recording, bool keep_fwd); - void StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool keep_fwd); - void StaticRunOps(const Context& default_ctx, - const nnvm::Graph& g, - const OpStatePtr& state_ptr, - const std::vector& state_arrays, - size_t start_nid, - size_t end_nid); - OpStatePtr StaticForward(const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs); + bool SetForwardGraph( + const Context& default_ctx, + GraphInfo* info, + const bool recording, + const std::vector& inputs); + bool SetBackwardGraph( + GraphInfo* info, + const std::vector& 
reqs,
+      const std::vector& inputs,
+      bool detect_inplace_addto = false);
+  bool CheckDynamicShapeExists(
+      const Context& default_ctx,
+      const std::vector& inputs,
+      bool erase_result);
+  void StaticAllocMemory(
+      const OpStatePtr& state_ptr,
+      bool recording,
+      bool keep_fwd);
+  void StaticInitExec(
+      const OpStatePtr& state_ptr,
+      bool recording,
+      bool keep_fwd);
+  void StaticRunOps(
+      const Context& default_ctx,
+      const nnvm::Graph& g,
+      const OpStatePtr& state_ptr,
+      const std::vector &state_arrays,
+      size_t start_nid,
+      size_t end_nid);
+  OpStatePtr StaticForward(
+      const Context& default_ctx,
+      const std::vector& inputs,
+      const std::vector& outputs);
   struct DynamicRuntime;
  private:
-  OpStatePtr DynamicForward(const Context& default_ctx,
-                            const std::vector& inputs,
-                            const std::vector& outputs,
-                            bool use_naive_run = false);
-  void DynamicBackward(const bool retain_graph,
-                       const OpStatePtr& op_state,
-                       const std::vector& inputs,
-                       const std::vector& reqs,
-                       const std::vector& outputs);
-  void StaticBackward(const bool retain_graph,
-                      const OpStatePtr& state_ptr,
-                      const std::vector& inputs,
-                      const std::vector& reqs,
-                      const std::vector& outputs);
+  OpStatePtr DynamicForward(
+      const Context& default_ctx,
+      const std::vector& inputs,
+      const std::vector& outputs,
+      bool use_naive_run = false);
+  void DynamicBackward(
+      const bool retain_graph,
+      const OpStatePtr& op_state,
+      const std::vector& inputs,
+      const std::vector& reqs,
+      const std::vector& outputs);
+  void StaticBackward(
+      const bool retain_graph,
+      const OpStatePtr& state_ptr,
+      const std::vector& inputs,
+      const std::vector& reqs,
+      const std::vector& outputs);
   size_t BwdOriginalInput(const std::vector& input_map, size_t new_i);
   CachedOpConfig config_;
@@ -649,16 +664,17 @@ class CachedOp {
   std::vector bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
   std::vector save_inputs_, save_outputs_;
   std::vector bwd_output_reqs_;
+  std::vector nleafs_;
   std::function monitor_callback_{nullptr};
   bool monitor_all_{false};
   std::mutex mutex_;
-  std::unordered_map> cached_op_states_;
+  std::unordered_map > cached_op_states_;
   friend class ::mxnet::io::LazyTransformDataset;
   nnvm::Symbol sym_;
-  std::vector> flags_;
+  std::vector > flags_;
 };
 struct CachedOp::DynamicRuntime {
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 73b87106f80a..a51e3b944c8c 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -163,6 +163,18 @@ void Imperative::MarkVariables(
   }
 }
 
+void Imperative::MarkDCVariables(const std::vector& nleafs, int cnt_vars) {
+  for (NDArray * nleaf : nleafs) {
+    if (Imperative::DCInfo::IsNone(*nleaf)) {
+      LOG(WARNING) << "The marked node doesn't have deferred compute history.";
+    } else {
+      nnvm::ObjectPtr node = nleaf->deferredcompute_entry_.node;
+      node->attrs.dict["mark_id"] = std::to_string(cnt_vars);
+    }
+    cnt_vars++;
+  }
+}
+
 // Unmark the variables to free the memory.
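For orientation, a minimal caller-side sketch (not part of the patch) of how the new Imperative::MarkDCVariables hook above can be driven from C++; the array names and the cnt_var bookkeeping below are assumptions for illustration only.

#include <mxnet/imperative.h>
#include <mxnet/ndarray.h>

#include <vector>

// Sketch only: `hidden` and `attention` stand in for non-leaf NDArrays that were
// produced under deferred compute; `cnt_var` tracks how many non-leafs the caller
// has already marked (these names are hypothetical, not from this change).
void MarkTwoNonLeafs(mxnet::NDArray* hidden, mxnet::NDArray* attention, int* cnt_var) {
  std::vector<mxnet::NDArray*> nleafs = {hidden, attention};
  // Tags each array's deferredcompute_entry_.node with attrs.dict["mark_id"] equal to
  // std::to_string(*cnt_var + k) for the k-th array; an array without deferred-compute
  // history only triggers the LOG(WARNING) branch above and is left untagged.
  mxnet::Imperative::Get()->MarkDCVariables(nleafs, *cnt_var);
  // cnt_vars is passed by value, so the caller keeps the running total itself.
  *cnt_var += static_cast<int>(nleafs.size());
}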
void Imperative::DropGrads(const std::vector& variables) { for (auto variable : variables) { diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index e3a58804d8ac..456dd8f7e142 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -27,7 +27,7 @@ std::vector NodeInputs(const nnvm::IndexedGraph& idx, const int node_idx, const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; - const size_t num_inputs = node.inputs.size(); + const size_t num_inputs = node.inputs.size(); std::vector ndinputs; ndinputs.reserve(num_inputs); for (const auto& j : node.inputs) { @@ -41,7 +41,7 @@ std::vector NodeOutputs(const nnvm::IndexedGraph& idx, const int node_idx, const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; - const size_t num_outputs = node.source->num_outputs(); + const size_t num_outputs = node.source->num_outputs(); std::vector ndoutputs; ndoutputs.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { @@ -55,7 +55,7 @@ std::vector NodeReq(const nnvm::IndexedGraph& idx, const int node_idx, const std::vector& array_reqs) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; - const size_t num_outputs = node.source->num_outputs(); + const size_t num_outputs = node.source->num_outputs(); std::vector req; req.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { @@ -126,19 +126,20 @@ void InvokeOperator(const nnvm::IndexedGraph& idx, namespace mxnet { namespace imperative { -void RunGraph(const bool retain_graph, - const nnvm::IndexedGraph& idx, - const std::vector& arrays, - size_t node_start, - size_t node_end, - std::vector&& array_reqs, - std::vector&& ref_count, - std::vector* p_states, - const DispatchModeVector& dispatch_modes, - bool recording, - mxnet::ShapeVector* shapes, - const imperative::CachedOpMonCallback& callback, - const bool monitor_all) { +void RunGraph( + const bool retain_graph, + const nnvm::IndexedGraph& idx, + const std::vector& arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector &dispatch_modes, + bool recording, + mxnet::ShapeVector *shapes, + const imperative::CachedOpMonCallback& callback, + const bool monitor_all, + const std::vector& nleafs) { CHECK(shapes == nullptr); for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; @@ -161,29 +162,39 @@ void RunGraph(const bool retain_graph, Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state); } }; - InvokeOperator( - idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, &req, &ref_count, invoke); + InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, + &req, &ref_count, invoke); if (callback) { - mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback); + mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback); + } + // set the autograd_entry_ in marked nleafs + if (nleafs.size()) { + auto it = node.source->attrs.dict.find("mark_id"); + if (it != node.source->attrs.dict.end()) { + int mark_id = std::stoi(it->second); + CHECK_LT(mark_id, nleafs.size()) + << "Mark_id exceeds the nonleaf list size."; + nleafs[mark_id]->copy_autograd_entry_(ndoutputs[0]); + } } } } -void NaiveRunGraph(const bool retain_graph, - const Context& default_ctx, - const nnvm::IndexedGraph& idx, - const std::vector& arrays, - size_t node_start, - size_t node_end, - 
std::vector&& array_reqs, - std::vector&& ref_count, - std::vector* p_states, - const DispatchModeVector& dispatch_modes, - bool recording, - mxnet::ShapeVector* shapes, - const imperative::CachedOpMonCallback& callback, - const bool monitor_all, - const bool skip_engine) { +void NaiveRunGraph( + const bool retain_graph, + const Context& default_ctx, + const nnvm::IndexedGraph& idx, + const std::vector& arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector &dispatch_modes, + bool recording, + mxnet::ShapeVector *shapes, + const imperative::CachedOpMonCallback& callback, + const bool monitor_all, + const bool skip_engine) { for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; if (node.source->op() == nullptr) { diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 7f90528f4793..0132eecefc3f 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -137,21 +137,23 @@ struct EngineOprSeg { std::unique_ptr opr; }; -using MemoryPlanVector = std::vector; +using MemoryPlanVector = std::vector; using CachedOpMonCallback = std::function; inline Context GetContext(const nnvm::NodeAttrs& attrs, - const std::vector& inputs, - const std::vector& outputs, - const Context& default_ctx) { + const std::vector& inputs, + const std::vector& outputs, + const Context& default_ctx) { Context ctx; if (inputs.size()) { ctx = inputs[0]->ctx(); for (size_t i = 1; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->ctx().dev_mask(), ctx.dev_mask()) - << "Operator " << attrs.op->name << " require all inputs live on the same context. " - << "But the first argument is on " << ctx << " while the " << i + 1 - << "-th argument is on " << inputs[i]->ctx(); + << "Operator " << attrs.op->name + << " require all inputs live on the same context. " + << "But the first argument is on " + << ctx << " while the " << i+1 << "-th argument is on " + << inputs[i]->ctx(); } } else if (outputs.size() && !outputs[0]->is_none()) { ctx = outputs[0]->ctx(); @@ -184,12 +186,12 @@ inline void SetShapeType(const Context& ctx, const std::vector& inputs, const std::vector& outputs, DispatchMode* dispatch_mode) { - static auto& infershape = nnvm::Op::GetAttr("FInferShape"); - static auto& infertype = nnvm::Op::GetAttr("FInferType"); - static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); - MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); + static auto& infershape = nnvm::Op::GetAttr("FInferShape"); + static auto& infertype = nnvm::Op::GetAttr("FInferType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); + MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); // infer shape - mxnet::ShapeVector& in_shapes = ret->arg_shapes; + mxnet::ShapeVector& in_shapes = ret->arg_shapes; in_shapes.clear(); in_shapes.reserve(inputs.size()); for (auto& i : inputs) { @@ -205,7 +207,7 @@ inline void SetShapeType(const Context& ctx, if (!is_dynamic_shape_existing) { // If any of the inputs is a deferred computed array with unknown shape, we // can't infer shapes. 
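Stepping back from the hunks above: a rough, non-authoritative sketch of how the marked non-leaf arrays are expected to travel through the CachedOp execution path. The glue that forwards the op's nleafs_ member into RunGraph is not visible in this section, so that wiring, plus every name and include path below, is an assumption for illustration.

#include <memory>
#include <vector>

#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/op_attr_types.h>

#include "imperative/cached_op.h"  // internal header; include path is illustrative

// Sketch: `op` is a CachedOp built from a hybridized block, `marked` are the
// NDArray pointers previously tagged via Imperative::MarkDCVariables.
mxnet::OpStatePtr RunWithMarkedNonLeafs(const std::shared_ptr<mxnet::CachedOp>& op,
                                        const std::vector<mxnet::NDArray*>& marked,
                                        const std::vector<mxnet::NDArray*>& inputs,
                                        const std::vector<mxnet::NDArray*>& outputs,
                                        const mxnet::Context& ctx) {
  // Stash the tagged arrays on the op (new set_nleafs accessor from this patch).
  op->set_nleafs(marked);
  // Assuming the op passes nleafs_ through to RunGraph: for each executed node,
  // RunGraph looks up the "mark_id" attribute and, on a match, copies the autograd
  // entry of that node's first output into nleafs[mark_id] via copy_autograd_entry_,
  // so the marked intermediate stays reachable for a later backward pass.
  return op->Forward(op, inputs, outputs, ctx);
}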
- for (const NDArray* i : inputs) { + for (const NDArray *i : inputs) { if (!shape_is_known(i->shape()) && !Imperative::DCInfo::IsNone(*i)) { is_dynamic_shape_existing = true; break; @@ -275,16 +277,15 @@ inline void SetShapeType(const Context& ctx, } bool infer_stype_success = false; if (inferstorage.count(attrs.op)) { - infer_stype_success = inferstorage[attrs.op]( - attrs, ctx.dev_mask(), dispatch_mode, &in_storage_types, &out_storage_types); + infer_stype_success = inferstorage[attrs.op](attrs, ctx.dev_mask(), dispatch_mode, + &in_storage_types, &out_storage_types); } else { // if infer storage attr is not present, apply the default infer storage function - infer_stype_success = common::DefaultStorageType( - attrs, ctx.dev_mask(), dispatch_mode, &in_storage_types, &out_storage_types); + infer_stype_success = common::DefaultStorageType(attrs, ctx.dev_mask(), dispatch_mode, + &in_storage_types, &out_storage_types); } CHECK(infer_stype_success) << "Operator not implemented: " - << common::operator_stype_string( - attrs, ctx.dev_mask(), in_storage_types, out_storage_types); + << common::operator_stype_string(attrs, ctx.dev_mask(), in_storage_types, out_storage_types); if (*dispatch_mode == DispatchMode::kFComputeFallback) { common::LogStorageFallback(attrs, ctx.dev_mask(), &in_storage_types, &out_storage_types); } @@ -293,12 +294,12 @@ inline void SetShapeType(const Context& ctx, CHECK(*dispatch_mode != DispatchMode::kUndefined); for (size_t i = 0; i < outputs.size(); ++i) { if (outputs[i]->is_none() || (mxnet::op::shape_is_none(outputs[i]->shape()) && - Imperative::DCInfo::IsNone(*outputs[i]))) { + Imperative::DCInfo::IsNone(*outputs[i]))) { if (!is_dynamic_shape_existing) { const auto storage_type = static_cast(out_storage_types[i]); outputs[i]->ReInit(storage_type, out_shapes[i], ctx, out_types[i]); } else { - *outputs[i] = NDArray(ctx, out_types[i]); + *outputs[i] = NDArray(ctx, out_types[i]); } outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); } else if (mxnet::op::shape_is_none(outputs[i]->shape())) { @@ -309,18 +310,18 @@ inline void SetShapeType(const Context& ctx, outputs[i]->Init(out_shapes[i]); } CHECK_EQ(outputs[i]->dtype(), out_types[i]) - << i << "-th output has invalid dtype. " - << "Expecting " << out_types[i] << " got " << outputs[i]->dtype() << " in operator " - << attrs.op->name; + << i << "-th output has invalid dtype. " + << "Expecting " << out_types[i] << " got " << outputs[i]->dtype() + << " in operator " << attrs.op->name; } else { CHECK_EQ(outputs[i]->shape(), out_shapes[i]) - << i << "-th output has invalid shape. " - << "Expecting " << out_shapes[i] << " got " << outputs[i]->shape() << " in operator " - << attrs.op->name; + << i << "-th output has invalid shape. " + << "Expecting " << out_shapes[i] << " got " + << outputs[i]->shape() << " in operator " << attrs.op->name; CHECK_EQ(outputs[i]->dtype(), out_types[i]) - << i << "-th output has invalid dtype. " - << "Expecting " << out_types[i] << " got " << outputs[i]->dtype() << " in operator " - << attrs.op->name; + << i << "-th output has invalid dtype. " + << "Expecting " << out_types[i] << " got " + << outputs[i]->dtype() << " in operator " << attrs.op->name; } } } @@ -330,53 +331,53 @@ inline void SetShapeType(const Context& ctx, * For inputs and outputs arguments only NDArray::var() is accessed. 
*/ inline void SetDependency(const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& inputs, - const std::vector& outputs, - std::vector* p_read_vars, - std::vector* p_write_vars, - std::vector* p_requested, - std::vector* p_mutate_idx, - const DispatchMode dispatch_mode) { - static auto& fmutate = nnvm::Op::GetAttr("FMutateInputs"); - static auto& ftmp_resource = nnvm::Op::GetAttr("FResourceRequest"); + const Context& ctx, + const std::vector& inputs, + const std::vector& outputs, + std::vector *p_read_vars, + std::vector *p_write_vars, + std::vector *p_requested, + std::vector *p_mutate_idx, + const DispatchMode dispatch_mode) { + static auto& fmutate = nnvm::Op::GetAttr("FMutateInputs"); + static auto& ftmp_resource = nnvm::Op::GetAttr("FResourceRequest"); static auto& ftmp_resource_ex = nnvm::Op::GetAttr("FResourceRequestEx"); std::vector& read_vars = *p_read_vars; std::vector& write_vars = *p_write_vars; - std::vector& requested = *p_requested; - std::vector& mutate_idx = *p_mutate_idx; + std::vector& requested = *p_requested; + std::vector& mutate_idx = *p_mutate_idx; if (fmutate.count(attrs.op)) { mutate_idx = fmutate[attrs.op](attrs); } - const bool rsc_req = (ftmp_resource.count(attrs.op) != 0); + const bool rsc_req = (ftmp_resource.count(attrs.op) != 0); const bool rsc_ex_req = (ftmp_resource_ex.count(attrs.op) != 0); if (rsc_req || rsc_ex_req) { - int ntmp = 0; - auto resource_reqs = rsc_ex_req ? ftmp_resource_ex[attrs.op]( - attrs, static_cast(ctx.dev_mask()), dispatch_mode) : - ftmp_resource[attrs.op](attrs); + int ntmp = 0; + auto resource_reqs = rsc_ex_req ? ftmp_resource_ex[attrs.op](attrs, + static_cast(ctx.dev_mask()), dispatch_mode) + : ftmp_resource[attrs.op](attrs); for (const auto& req : resource_reqs) { switch (req.type) { - case ResourceRequest::kTempSpace: - ++ntmp; - case ResourceRequest::kRandom: - requested.push_back(ResourceManager::Get()->Request(ctx, req)); - write_vars.push_back(requested.back().var); - break; - case ResourceRequest::kParallelRandom: - requested.push_back(ResourceManager::Get()->Request(ctx, req)); - write_vars.push_back(requested.back().var); - break; + case ResourceRequest::kTempSpace: + ++ntmp; + case ResourceRequest::kRandom: + requested.push_back(ResourceManager::Get()->Request(ctx, req)); + write_vars.push_back(requested.back().var); + break; + case ResourceRequest::kParallelRandom: + requested.push_back(ResourceManager::Get()->Request(ctx, req)); + write_vars.push_back(requested.back().var); + break; #if MXNET_USE_CUDNN == 1 - case ResourceRequest::kCuDNNDropoutDesc: - requested.push_back(ResourceManager::Get()->Request(ctx, req)); - write_vars.push_back(requested.back().var); - break; + case ResourceRequest::kCuDNNDropoutDesc: + requested.push_back(ResourceManager::Get()->Request(ctx, req)); + write_vars.push_back(requested.back().var); + break; #endif // MXNET_USE_CUDNN == 1 - default: - LOG(FATAL) << "resource type not yet supported"; + default: + LOG(FATAL) << "resource type not yet supported"; } } CHECK_LE(ntmp, 1) << "Only support 1 temp space request"; @@ -396,7 +397,7 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs, for (auto& i : outputs) { write_vars.push_back(i->var()); } - for (auto& i : mutate_idx) { + for (auto & i : mutate_idx) { write_vars.push_back(inputs[i]->var()); } Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars); @@ -408,11 +409,11 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs, * NDArray. Set to kWriteTo otherwise. 
*/ inline void SetWriteInplaceReq(const std::vector& inputs, - const std::vector& outputs, - std::vector* req) { + const std::vector& outputs, + std::vector *req) { std::unordered_set in_vars; in_vars.reserve(inputs.size()); - for (auto& i : inputs) { + for (auto &i : inputs) { in_vars.insert(i->var()); } req->clear(); @@ -434,16 +435,16 @@ inline void SetWriteInplaceReq(const std::vector& inputs, * \param param_vals Array of string pointers representing the associated values * \return nnvm::NodeAttrs structure representing the parsed attributes */ -inline nnvm::NodeAttrs ParseAttrs(const nnvm::Op* op, +inline nnvm::NodeAttrs ParseAttrs(const nnvm::Op *op, const int num_inputs, const int num_params, - const char** param_keys, - const char** param_vals) { + const char **param_keys, + const char **param_vals) { static auto& num_args = nnvm::Op::GetAttr("key_var_num_args"); nnvm::NodeAttrs attrs; attrs.op = op; - attrs.dict.reserve(num_params + 1); + attrs.dict.reserve(num_params+1); for (int i = 0; i < num_params; ++i) { attrs.dict.emplace(param_keys[i], param_vals[i]); } @@ -465,7 +466,7 @@ inline nnvm::NodeAttrs ParseAttrs(const nnvm::Op* op, * \param infered_num_outputs The inferred number of outputs * \param num_visible_outputs The actual number of visible outputs */ -inline void SetNumOutputs(const nnvm::Op* op, +inline void SetNumOutputs(const nnvm::Op *op, const nnvm::NodeAttrs& attrs, const int& num_inputs, int* infered_num_outputs, @@ -478,8 +479,8 @@ inline void SetNumOutputs(const nnvm::Op* op, infered_num_inputs = op->num_inputs; } CHECK_EQ(num_inputs, infered_num_inputs) - << "Operator " << op->name << " expects " << infered_num_inputs << " inputs, but got " - << num_inputs << " instead."; + << "Operator " << op->name << " expects " << infered_num_inputs + << " inputs, but got " << num_inputs << " instead."; if (op->get_num_outputs != nullptr) { *infered_num_outputs = op->get_num_outputs(attrs); } else { @@ -495,36 +496,30 @@ inline void SetNumOutputs(const nnvm::Op* op, /*! 
* \brief Copy-construct NDArrays referenced by inputs and outputs to p_inputs and p_outputs */ -inline void DerefInputOutput(const std::vector& inputs, - const std::vector& outputs, +inline void DerefInputOutput(const std::vector& inputs, + const std::vector& outputs, std::vector* p_inputs, std::vector* p_outputs) { p_inputs->reserve(inputs.size()); p_outputs->reserve(outputs.size()); - for (const auto i : inputs) - p_inputs->emplace_back(*i); - for (const auto i : outputs) - p_outputs->emplace_back(*i); + for (const auto i : inputs) p_inputs->emplace_back(*i); + for (const auto i : outputs) p_outputs->emplace_back(*i); } inline void DerefInputOutput(const std::vector& inputs, const std::vector& outputs, - std::vector* p_inputs, - std::vector* p_outputs) { + std::vector* p_inputs, + std::vector* p_outputs) { p_inputs->reserve(inputs.size()); p_outputs->reserve(outputs.size()); - for (const auto i : inputs) - p_inputs->emplace_back(new NDArray(*i)); - for (const auto i : outputs) - p_outputs->emplace_back(new NDArray(*i)); + for (const auto i : inputs) p_inputs->emplace_back(new NDArray(*i)); + for (const auto i : outputs) p_outputs->emplace_back(new NDArray(*i)); } -inline void DerefInputOutputRelease(const std::vector& inputs, - const std::vector& outputs) { - for (auto i : inputs) - delete i; - for (auto i : outputs) - delete i; +inline void DerefInputOutputRelease(const std::vector& inputs, + const std::vector& outputs) { + for (auto i : inputs) delete i; + for (auto i : outputs) delete i; } /* @@ -540,19 +535,19 @@ inline void DerefInputOutputRelease(const std::vector& inputs, indices are not recorded * \return true if any source NDArray need to cast storage */ -inline bool SetupDefaultBlobsIn(const std::vector& src, - const std::vector* bufs, - std::vector* blobs, - std::vector* temp_src, - std::vector* temp_dst, - std::unordered_map* idx_map) { +inline bool SetupDefaultBlobsIn(const std::vector& src, + const std::vector *bufs, + std::vector *blobs, + std::vector *temp_src, + std::vector *temp_dst, + std::unordered_map *idx_map) { bool require_cast = false; for (size_t i = 0; i < src.size(); i++) { const auto& nd = *src[i]; if (!DEFAULT_DATA(nd)) { (*idx_map)[i] = temp_dst->size(); - NDArray temp = - bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); + NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), + true, nd.dtype()); #if MXNET_USE_ONEDNN == 1 CHECK(temp.IsDefaultData()); #endif @@ -567,12 +562,12 @@ inline bool SetupDefaultBlobsIn(const std::vector& src, return require_cast; } -inline bool SetupDefaultBlobsOut(const std::vector& src, - const std::vector* bufs, - std::vector* req, - std::vector* blobs, - std::vector* temp_src, - std::vector* temp_dst) { +inline bool SetupDefaultBlobsOut(const std::vector& src, + const std::vector *bufs, + std::vector *req, + std::vector *blobs, + std::vector *temp_src, + std::vector *temp_dst) { bool require_cast = false; for (size_t i = 0; i < src.size(); i++) { const auto& nd = *src[i]; @@ -597,8 +592,8 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, } CHECK(temp.IsDefaultData()); #else - NDArray temp = - bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); + NDArray temp = bufs != nullptr ? 
bufs->at(i) : NDArray(nd.shape(), nd.ctx(), + true, nd.dtype()); #endif temp_src->emplace_back(nd); temp_dst->emplace_back(temp); @@ -618,23 +613,25 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, * function also records the indices of non-default source NDArrays and the indices of * their corresponding temporary NDArrays in the temp array. */ -inline void SetupDefaultBlobsInOut(const std::vector& ndinputs, - const std::vector& ndoutputs, - const std::vector* in_bufs, - const std::vector* out_bufs, - std::vector* req, - std::vector* input_blobs, - std::vector* output_blobs, - std::vector* pre_temp_src, - std::vector* pre_temp_dst, - std::vector* post_temp_src, - std::vector* post_temp_dst, - std::unordered_map* in_temp_idx_map, - const std::vector& mutate_idx) { +inline void SetupDefaultBlobsInOut(const std::vector &ndinputs, + const std::vector &ndoutputs, + const std::vector *in_bufs, + const std::vector *out_bufs, + std::vector *req, + std::vector *input_blobs, + std::vector *output_blobs, + std::vector *pre_temp_src, + std::vector *pre_temp_dst, + std::vector *post_temp_src, + std::vector *post_temp_dst, + std::unordered_map *in_temp_idx_map, + const std::vector &mutate_idx) { // populate input blobs - SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst, in_temp_idx_map); + SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst, + in_temp_idx_map); // populate output blobs - SetupDefaultBlobsOut(ndoutputs, out_bufs, req, output_blobs, post_temp_dst, post_temp_src); + SetupDefaultBlobsOut(ndoutputs, out_bufs, req, output_blobs, post_temp_dst, + post_temp_src); // add mutable inputs to post temp list for (const auto idx : mutate_idx) { auto map_iter = in_temp_idx_map->find(idx); @@ -645,30 +642,30 @@ inline void SetupDefaultBlobsInOut(const std::vector& ndinputs, } } -#define REDEFINE_INPUTS_OUTPUTS(in, out, newIn, newOut) \ - std::vector newIn, newOut; \ - DerefInputOutput(in, out, &newIn, &newOut); \ - DerefInputOutputRelease(in, out) +#define REDEFINE_INPUTS_OUTPUTS(in, out, newIn, newOut) \ + std::vector newIn, newOut; \ + DerefInputOutput(in, out, &newIn, &newOut); \ + DerefInputOutputRelease(in, out) inline void PushFCompute(const FCompute& fn, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& read_vars, - const std::vector& write_vars, - const std::vector& requested, - const std::vector& p_inputs, - const std::vector& p_outputs, - const std::vector& mutate_idx, - const std::vector& req) { + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& p_inputs, + const std::vector& p_outputs, + const std::vector& mutate_idx, + const std::vector& req) { using namespace common; static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - bool is_train = Imperative::Get()->is_training(); - bool need_grad = Imperative::Get()->is_recording(); + bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? 
fexec_type[op](attrs) : ExecType::kSync; CHECK(exec_type == ExecType::kSync); - std::vector inputs, outputs; + std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { std::vector input_blobs, output_blobs; @@ -679,19 +676,9 @@ inline void PushFCompute(const FCompute& fn, INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy, outputs, req); std::vector tmp_req = req; // setup blobs - SetupDefaultBlobsInOut(inputs, - outputs, - nullptr, - nullptr, - &tmp_req, - &input_blobs, - &output_blobs, - &pre_temp_src, - &pre_temp_dst, - &post_temp_src, - &post_temp_dst, - &in_temp_idx_map, - mutate_idx); + SetupDefaultBlobsInOut(inputs, outputs, nullptr, nullptr, &tmp_req, + &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, + &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx); // setup context OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; bool is_gpu = ctx.dev_mask() == gpu::kDevMask; @@ -707,63 +694,67 @@ inline void PushFCompute(const FCompute& fn, run(RunContext{ctx, nullptr, nullptr}); } else { Engine::Get()->PushSync( - run, ctx, read_vars, write_vars, FnProperty::kNormal, 0, op->name.c_str()); + run, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, op->name.c_str()); } } inline void PushFComputeEx(const FComputeEx& fn, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& read_vars, - const std::vector& write_vars, - const std::vector& requested, - const std::vector& p_inputs, - const std::vector& p_outputs, - const std::vector& req) { + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& p_inputs, + const std::vector& p_outputs, + const std::vector& req) { static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - const bool is_train = Imperative::Get()->is_training(); - const bool need_grad = Imperative::Get()->is_recording(); - const auto exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync; + const bool is_train = Imperative::Get()->is_training(); + const bool need_grad = Imperative::Get()->is_recording(); + const auto exec_type = fexec_type.count(op) ? 
fexec_type[op](attrs) : ExecType::kSync; const auto cross_device_copy = exec_type == ExecType::kCrossDeviceCopy; - std::vector inputs, outputs; + std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { - OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; - REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA); - INVALIDATE_OUTPUTS_COND(!cross_device_copy, outputsA, req); - CREATE_DEFAULT_INPUTS(!cross_device_copy, attrs, CreateDefaultInputs(&inputsA)); - fn(attrs, opctx, inputsA, req, outputsA); - }; + OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; + REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA); + INVALIDATE_OUTPUTS_COND(!cross_device_copy, outputsA, req); + CREATE_DEFAULT_INPUTS(!cross_device_copy, attrs, CreateDefaultInputs(&inputsA)); + fn(attrs, opctx, inputsA, req, outputsA); + if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) { + rctx.get_stream()->Wait(); + } + }; if (cross_device_copy || CheckIfSkipEngine(attrs)) { run(RunContext{ctx, nullptr, nullptr}); } else { CHECK(exec_type == ExecType::kSync); - Engine::Get()->PushSync( - run, ctx, read_vars, write_vars, FnProperty::kNormal, 0, op->name.c_str()); + Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, op->name.c_str()); } } inline void PushOperator(const OpStatePtr& state, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& read_vars, - const std::vector& write_vars, - const std::vector& requested, - const std::vector& p_inputs, - const std::vector& p_outputs, - const std::vector& mutate_idx, - const std::vector& req, - const DispatchMode dispatch_mode) { + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& p_inputs, + const std::vector& p_outputs, + const std::vector& mutate_idx, + const std::vector& req, + const DispatchMode dispatch_mode) { using namespace common; static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - bool is_train = Imperative::Get()->is_training(); - bool need_grad = Imperative::Get()->is_recording(); + bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? 
fexec_type[op](attrs) : ExecType::kSync; - std::vector inputs, outputs; + std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); auto fcompute_ex = common::GetFCompute(op, "FStatefulComputeEx", ctx); @@ -773,12 +764,10 @@ inline void PushOperator(const OpStatePtr& state, engine::CallbackOnComplete on_complete) { OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA); - INVALIDATE_OUTPUTS_COND( - exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp", outputsA, req); + INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp", + outputsA, req); CREATE_DEFAULT_INPUTS(exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp", - attrs, - CreateDefaultInputs(&inputsA)); - on_start(); + attrs, CreateDefaultInputs(&inputsA)); fcompute_ex(state, opctx, inputsA, req, outputsA); }; @@ -789,19 +778,14 @@ inline void PushOperator(const OpStatePtr& state, run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); } else if (exec_type == ExecType::kSync) { Engine::Get()->PushSync( - [=](RunContext rctx) { - run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); - }, - ctx, - read_vars, - write_vars, - FnProperty::kNormal, - 0, + [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, + ctx, read_vars, write_vars, FnProperty::kNormal, 0, op->name.c_str()); } else { CHECK(exec_type == ExecType::kAsync); - Engine::Get()->PushAsync( - run, ctx, read_vars, write_vars, FnProperty::kAsync, 0, op->name.c_str()); + Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, + FnProperty::kAsync, 0, + op->name.c_str()); } } else { auto fcompute = common::GetFCompute(op, "FStatefulCompute", ctx); @@ -809,42 +793,34 @@ inline void PushOperator(const OpStatePtr& state, << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; - const auto& run = [=](RunContext rctx, - engine::CallbackOnStart on_start, - engine::CallbackOnComplete on_complete) { - OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; - - std::vector input_blobs, output_blobs; - // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays - std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; - // mapping from index in input_blobs to index in pre_temp_dst - std::unordered_map in_temp_idx_map; - INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy, outputs, req); - - std::vector tmp_req = req; - // populate input blobs and output blobs - SetupDefaultBlobsInOut(inputs, - outputs, - nullptr, - nullptr, - &tmp_req, - &input_blobs, - &output_blobs, - &pre_temp_src, - &pre_temp_dst, - &post_temp_src, - &post_temp_dst, - &in_temp_idx_map, - mutate_idx); - // setup contexts - const bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask; - // pre-fcompute fallback - CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu); - fcompute(state, opctx, input_blobs, tmp_req, output_blobs); - // post-fcompute fallback, cast to original storage type, if necessary - CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); - DerefInputOutputRelease(inputs, outputs); - }; + const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { + OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; + + std::vector input_blobs, output_blobs; + // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays + 
std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; + // mapping from index in input_blobs to index in pre_temp_dst + std::unordered_map in_temp_idx_map; + INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy, outputs, req); + + std::vector tmp_req = req; + // populate input blobs and output blobs + SetupDefaultBlobsInOut(inputs, outputs, nullptr, nullptr, &tmp_req, + &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, + &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx); + // setup contexts + const bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask; + // pre-fcompute fallback + CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu); + fcompute(state, opctx, input_blobs, tmp_req, output_blobs); + // post-fcompute fallback, cast to original storage type, if necessary + CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); + if (is_gpu && exec_type == ExecType::kSync + && rctx.get_stream() && !rctx.is_bulk) { + rctx.get_stream()->Wait(); + } + DerefInputOutputRelease(inputs, outputs); + }; if (exec_type == ExecType::kSubgraphExec || CheckIfSkipEngine(attrs)) { RunContext rctx{ctx, nullptr}; @@ -852,28 +828,23 @@ inline void PushOperator(const OpStatePtr& state, } else if (exec_type == ExecType::kSync) { Engine::Get()->PushSync( [=](RunContext rctx) { - run(rctx, engine::CallbackOnStart(), engine::CallbackOnComplete()); - }, - ctx, - read_vars, - write_vars, - FnProperty::kNormal, - 0, - op->name.c_str()); + run(rctx, engine::CallbackOnComplete()); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, op->name.c_str()); } else { CHECK(exec_type == ExecType::kAsync); Engine::Get()->PushAsync( - run, ctx, read_vars, write_vars, FnProperty::kAsync, 0, op->name.c_str()); + run, ctx, read_vars, write_vars, FnProperty::kAsync, + 0, op->name.c_str()); } } } -inline bool CheckAndInferShape(nnvm::Graph* p_g, - mxnet::ShapeVector&& shapes, +inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, bool use_inputs, - std::pair node_range = {0, 0}, + std::pair node_range = {0, 0}, std::pair entry_range = {0, 0}, - bool* contain_unknown = nullptr) { + bool *contain_unknown = nullptr) { using namespace nnvm; if (contain_unknown != nullptr) { *contain_unknown = false; @@ -889,16 +860,13 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, for (size_t i = 0; i < shapes.size(); ++i) { if (i == entry_range.first) { i = entry_range.second; - if (i >= shapes.size()) - break; + if (i >= shapes.size()) break; } - if (shapes[i] == prev_shapes[i]) - continue; + if (shapes[i] == prev_shapes[i]) continue; match = false; break; } - if (match) - return true; + if (match) return true; } } g.attrs.erase("shape"); @@ -910,7 +878,7 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, g = exec::InferShape(std::move(g), std::move(shapes)); } else { g.attrs["shape"] = std::make_shared(std::move(shapes)); - g = exec::InferShape(std::move(g)); + g = exec::InferShape(std::move(g)); } if (contain_unknown == nullptr) { CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0U); @@ -920,16 +888,16 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, return false; } -inline bool CheckAndInferType(nnvm::Graph* p_g, - nnvm::DTypeVector&& dtypes, + +inline bool CheckAndInferType(nnvm::Graph* p_g, nnvm::DTypeVector&& dtypes, bool use_inputs, - std::pair node_range = {0, 0}, + std::pair node_range = {0, 0}, std::pair entry_range = {0, 0}) { using namespace nnvm; nnvm::Graph& g = *p_g; if (use_inputs) { - if (g.attrs.count("dtype_inputs") && 
g.GetAttr("dtype_inputs") == dtypes) - return true; + if (g.attrs.count("dtype_inputs") && + g.GetAttr("dtype_inputs") == dtypes) return true; } else if (g.attrs.count("dtype")) { const auto& prev_dtypes = g.GetAttr("dtype"); CHECK_EQ(prev_dtypes.size(), dtypes.size()); @@ -937,16 +905,13 @@ inline bool CheckAndInferType(nnvm::Graph* p_g, for (size_t i = 0; i < dtypes.size(); ++i) { if (i == entry_range.first) { i = entry_range.second; - if (i >= dtypes.size()) - break; + if (i >= dtypes.size()) break; } - if (dtypes[i] == prev_dtypes[i]) - continue; + if (dtypes[i] == prev_dtypes[i]) continue; match = false; break; } - if (match) - return true; + if (match) return true; } g.attrs.erase("dtype"); g.attrs.erase("dtype_inputs"); @@ -960,31 +925,28 @@ inline bool CheckAndInferType(nnvm::Graph* p_g, g = exec::InferType(std::move(g), std::move(dtypes)); } else { g.attrs["dtype"] = std::make_shared(std::move(dtypes)); - g = exec::InferType(std::move(g)); + g = exec::InferType(std::move(g)); } CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0U); return false; } -inline bool CheckAndInferStorageType(nnvm::Graph* p_g, - exec::DevMaskVector&& dev_mask, - StorageTypeVector&& storage_types, - bool use_inputs, - std::pair node_range = {0, 0}, +inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev_mask, + StorageTypeVector&& storage_types, bool use_inputs, + std::pair node_range = {0, 0}, std::pair entry_range = {0, 0}) { using namespace nnvm; nnvm::Graph& g = *p_g; - bool dev_match = - g.attrs.count("dev_mask") && g.GetAttr("dev_mask") == dev_mask; + bool dev_match = g.attrs.count("dev_mask") && + g.GetAttr("dev_mask") == dev_mask; if (!dev_match) { g.attrs["dev_mask"] = std::make_shared(std::move(dev_mask)); } if (dev_match && use_inputs) { if (g.attrs.count("storage_type_inputs") && - g.GetAttr("storage_type_inputs") == storage_types) - return true; + g.GetAttr("storage_type_inputs") == storage_types) return true; } else if (dev_match && g.attrs.count("storage_type")) { const auto& prev_storage_types = g.GetAttr("storage_type"); CHECK_EQ(prev_storage_types.size(), storage_types.size()); @@ -992,16 +954,13 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, for (size_t i = 0; i < storage_types.size(); ++i) { if (i == entry_range.first) { i = entry_range.second; - if (i >= storage_types.size()) - break; + if (i >= storage_types.size()) break; } - if (storage_types[i] == prev_storage_types[i]) - continue; + if (storage_types[i] == prev_storage_types[i]) continue; match = false; break; } - if (match) - return true; + if (match) return true; } g.attrs.erase("dispatch_mode"); g.attrs.erase("storage_type"); @@ -1013,17 +972,18 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, g = exec::InferStorageType(std::move(g), std::move(storage_types)); } else { g.attrs["storage_type"] = std::make_shared(std::move(storage_types)); - g = exec::InferStorageType(std::move(g)); + g = exec::InferStorageType(std::move(g)); } CHECK_EQ(g.GetAttr("storage_type_num_unknown_nodes"), 0U); return false; } + inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { static const auto& _copyto = Op::Get("_copyto"); - std::vector vctx(idx.num_nodes(), - Context::Create(static_cast(-1), 0)); + std::vector vctx( + idx.num_nodes(), Context::Create(static_cast(-1), 0)); // forward pass for (size_t i = 0; i < idx.num_nodes(); ++i) { if (!idx[i].source->info.empty()) { @@ -1038,8 +998,7 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { vctx[i] = vctx[idx[i].control_deps[0]]; } 
else { for (const auto& in : idx[i].inputs) { - if (vctx[in.node_id].dev_type == static_cast(-1)) - continue; + if (vctx[in.node_id].dev_type == static_cast(-1)) continue; vctx[i] = vctx[in.node_id]; break; } @@ -1047,12 +1006,10 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { } // backward pass for (int i = idx.num_nodes() - 1; i >= 0; --i) { - if (vctx[i].dev_type == static_cast(-1)) - continue; + if (vctx[i].dev_type == static_cast(-1)) continue; if (idx[i].source->op() == _copyto) { auto in_nid = idx[i].inputs[0].node_id; - if (vctx[in_nid].dev_type != static_cast(-1)) - continue; + if (vctx[in_nid].dev_type != static_cast(-1)) continue; CHECK_GT(idx[i].source->control_deps.size(), 0); auto fwd_nid = idx.node_id(idx[i].source->control_deps[0].get()); CHECK_EQ(idx[fwd_nid].source->op(), _copyto); @@ -1060,8 +1017,7 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { continue; } for (const auto& j : idx[i].inputs) { - if (vctx[j.node_id].dev_type != static_cast(-1)) - continue; + if (vctx[j.node_id].dev_type != static_cast(-1)) continue; vctx[j.node_id] = vctx[i]; } } @@ -1095,12 +1051,12 @@ inline MemoryPlanVector MXPlanMemory(nnvm::Graph* p_g, if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g); - const auto& dtypes = g.GetAttr("dtype"); - const auto& shapes = g.GetAttr("shape"); + const auto& dtypes = g.GetAttr("dtype"); + const auto& shapes = g.GetAttr("shape"); const auto& storage_inplace = g.GetAttr >("storage_inplace_index"); - g.attrs[storage_plan] = std::make_shared(storage_inplace); - const auto& storage_ids = g.GetAttr("storage_id"); - uint32_t entry_start = entry_range.first; + g.attrs[storage_plan] = std::make_shared(storage_inplace); + const auto& storage_ids = g.GetAttr("storage_id"); + uint32_t entry_start = entry_range.first; uint32_t entry_end = entry_range.second > entry_start ? 
entry_range.second : idx.num_node_entries(); MemoryPlanVector mem_plan(idx.num_node_entries()); @@ -1112,13 +1068,14 @@ inline MemoryPlanVector MXPlanMemory(nnvm::Graph* p_g, } else if (!sid_to_root.count(storage_ids[i])) { CHECK_LT(storage_inplace[i], 0); sid_to_root[storage_ids[i]] = i; - mem_plan[i] = { - storage_ids[i], i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false}; + mem_plan[i] = {storage_ids[i], i, + mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), + false}; } else { uint32_t root = sid_to_root[storage_ids[i]]; - mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0}; - mem_plan[root].size = - std::max(mem_plan[root].size, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size()); + mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0}; + mem_plan[root].size = std::max(mem_plan[root].size, + mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size()); } } @@ -1129,11 +1086,10 @@ inline std::multimap AllocateMemory( const nnvm::Graph& g, const nnvm::IndexedGraph& idx, const Context& default_ctx, - const uint32_t entry_start, - const uint32_t entry_end, + const uint32_t entry_start, const uint32_t entry_end, const MemoryPlanVector& mem_plan, const std::vector& arrays, - std::vector* array_reqs, + std::vector *array_reqs, std::multimap&& pool = std::multimap()) { using namespace nnvm; const auto& dtypes = g.GetAttr("dtype"); @@ -1152,19 +1108,18 @@ inline std::multimap AllocateMemory( continue; } data_entry_profiler_scopes[eid - entry_start] = profiler_scope; - data_entry_names[eid - entry_start] = idx[nid].source->attrs.name; + data_entry_names[eid - entry_start] = idx[nid].source->attrs.name; } } - const NDArray* pntr; + const NDArray *pntr; for (uint32_t i = entry_start; i < entry_end; ++i) { - const auto& plan = mem_plan[i]; - if (plan.storage_id == exec::kExternalStorageID) - continue; + const auto &plan = mem_plan[i]; + if (plan.storage_id == exec::kExternalStorageID) continue; CHECK(arrays[i]->is_none()); if (plan.storage_id == exec::kDynamicStorageID) { - *arrays[i] = NDArray( - static_cast(stypes[i]), shapes[i], default_ctx, true, dtypes[i]); + *arrays[i] = NDArray(static_cast(stypes[i]), + shapes[i], default_ctx, true, dtypes[i]); arrays[i]->AssignStorageInfo(data_entry_profiler_scopes[i - entry_start], data_entry_names[i - entry_start]); continue; @@ -1177,9 +1132,7 @@ inline std::multimap AllocateMemory( pool.erase(iter); } else { NDArray buff(mxnet::TShape({static_cast(plan.size)}), - default_ctx, - true, - mshadow::kUint8); + default_ctx, true, mshadow::kUint8); buff.AssignStorageInfo(data_entry_profiler_scopes[i - entry_start], data_entry_names[i - entry_start]); pntr = &new_pool.insert({plan.size, buff})->second; @@ -1196,12 +1149,13 @@ inline std::multimap AllocateMemory( return new_pool; } -inline void SetupOpExec(const nnvm::Graph& g, - size_t nid, - const std::shared_ptr& exec, - const std::vector arrays, - const std::vector array_reqs) { - const auto& idx = g.indexed_graph(); +inline void SetupOpExec( + const nnvm::Graph& g, + size_t nid, + const std::shared_ptr& exec, + const std::vector arrays, + const std::vector array_reqs) { + const auto& idx = g.indexed_graph(); const auto& inode = idx[nid]; CHECK_EQ(exec->in_array.size(), 0U); CHECK_EQ(exec->out_array.size(), 0U); @@ -1377,32 +1331,31 @@ inline void CreateEngineOpSeg(const nnvm::IndexedGraph& idx, void RunGraph(const bool retain_graph, const nnvm::IndexedGraph& idx, const std::vector& arrays, - size_t node_start, - size_t node_end, + size_t node_start, size_t node_end, 
               std::vector<OpReqType>&& array_reqs,
               std::vector<uint32_t>&& ref_count,
-              std::vector<OpStatePtr>* p_states,
-              const DispatchModeVector& dispatch_modes,
+              std::vector<OpStatePtr> *p_states,
+              const DispatchModeVector &dispatch_modes,
               bool recording,
-              mxnet::ShapeVector* shapes = nullptr,
+              mxnet::ShapeVector *shapes = nullptr,
               const CachedOpMonCallback& callback = nullptr,
-              const bool monitor_all_ = false);
+              const bool monitor_all_ = false,
+              const std::vector<NDArray*>& nleafs = std::vector<NDArray*>());
 
 void NaiveRunGraph(const bool retain_graph,
                    const Context& default_ctx,
                    const nnvm::IndexedGraph& idx,
                    const std::vector<NDArray*>& arrays,
-                   size_t node_start,
-                   size_t node_end,
+                   size_t node_start, size_t node_end,
                    std::vector<OpReqType>&& array_reqs,
                    std::vector<uint32_t>&& ref_count,
-                   std::vector<OpStatePtr>* p_states,
-                   const DispatchModeVector& dispatch_modes,
+                   std::vector<OpStatePtr> *p_states,
+                   const DispatchModeVector &dispatch_modes,
                    bool recording,
-                   mxnet::ShapeVector* shapes,
+                   mxnet::ShapeVector *shapes,
                    const CachedOpMonCallback& callback = nullptr,
-                   const bool monitor_all_ = false,
-                   const bool skip_engine = false);
+                   const bool monitor_all_ = false,
+                   const bool skip_engine = false);
 
 }  // namespace imperative
 }  // namespace mxnet
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 902880fb1d52..f4cf02c15633 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -513,6 +513,10 @@ void NDArray::set_fresh_out_grad(bool state) const {
   info.fresh_out_grad = state;
 }
 
+void NDArray::copy_autograd_entry_(const NDArray* src) {
+  autograd_entry_ = nnvm::NodeEntry{src->autograd_entry_.node, 0, 0};
+}
+
 #if MXNET_USE_ONEDNN == 1
 
 bool NDArray::Chunk::IsDNNL() const {
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index c48d20479f15..a103d8e917a0 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -533,7 +533,7 @@ def test_retain_grad_drop_grad():
     z.attach_grad()
     out_grad = nd.array([10, 10, 10, 10])
     z.backward(out_grad, retain_graph=True)
-    
+
     assert (u.grad == out_grad * x).asnumpy().all()
     assert (z.grad == out_grad).asnumpy().all()
     assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
@@ -548,39 +548,48 @@
     assert u.grad is None and z.grad is None and y.grad is None
     assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
 
-def test_retain_grad_drop_grad_gluon():
+@pytest.mark.parametrize('hybridize', [True, False])
+def test_retain_grad_drop_grad_gluon(hybridize):
     class CompBlock(mx.gluon.HybridBlock):
         def __init__(self):
             super().__init__()
-            self.marked_var = None
-        def forward(self, a, b):
-            out1 = a*b
-            out2 = out1 * a
-            self.marked_var = out1
+
+        def forward(self, a, b, c):
+            out1 = self.intermediate(('out1_0', 'out1_1'), ((a+b)*c, a*b), grad_req='write')
+            out2 = self.intermediate('out2', out1[1] * a)
             return out2
+
     x = mx.np.array([1,2,3,4])
     y = mx.np.array([5,6,7,8])
+    w = mx.np.array([0.1, 0.1, 0.1, 0.1])
     x.attach_grad()
     y.attach_grad()
+    w.attach_grad()
     block2 = CompBlock()
     block2.initialize()
-    # block2.hybridize()
+    if hybridize:
+        block2.hybridize()
     with mx.autograd.record():
-        z = block2(x, y)
-    u = block2.marked_var
-    u.attach_grad()
-    z.attach_grad()
+        z = block2(x, y, w)
+
+    block2.attach_grad_intermediate()
+    u0 = block2.get_intermediate('out1_0').data()
+    u = block2.get_intermediate('out1_1').data()
+    z = block2.get_intermediate('out2').data()
     z.backward(retain_graph=True)
 
     assert (u.grad == x).all()
+    assert (u0.grad == mx.np.array([0, 0, 0, 0])).all()
     assert (z.grad == mx.np.array([1,1,1,1])).all()
     assert (x.grad == 2 * x * y).all()
     assert (y.grad == x*x).all()
 
     u.drop_grad()
+    u0.drop_grad()
     z.drop_grad()
     y.drop_grad()
     z.backward()
-    assert u.grad is None and z.grad is None and y.grad is None
+    assert u.grad is None and u0.grad is None and y.grad is None and z.grad is None
     assert (x.grad == 2 * x * y).all()
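
Usage note: the sketch below is a minimal, self-contained example of the intermediate-variable workflow that test_retain_grad_drop_grad_gluon above exercises. It assumes an MXNet build with this patch applied; the Net block, the names 'h' and 'out', and the input values are illustrative only and are not part of the change.

# Illustrative sketch only: mirrors the pattern used in the test above.
# Assumes an MXNet build that includes this patch; `Net`, the names 'h'/'out',
# and the input values are made up for the example.
import mxnet as mx
from mxnet.gluon import HybridBlock


class Net(HybridBlock):
    def forward(self, a, b):
        # Mark a*b as a named intermediate; also mark the output so it can be
        # fetched back after a hybridized (deferred-compute) forward pass.
        h = self.intermediate('h', a * b)
        return self.intermediate('out', h * a)


net = Net()
net.initialize()
net.hybridize()          # the eager (non-hybridized) path works the same way

x = mx.np.array([1., 2., 3., 4.])
y = mx.np.array([5., 6., 7., 8.])
x.attach_grad()
y.attach_grad()

with mx.autograd.record():
    net(x, y)

# Attach gradient buffers to every marked intermediate, then run backward
# from the marked output, exactly as the test does.
net.attach_grad_intermediate()
h = net.get_intermediate('h').data()
out = net.get_intermediate('out').data()
out.backward()

print(h.grad)   # d(out)/dh = a, i.e. equals x
print(x.grad)   # out = x*x*y, so x.grad = 2*x*y
print(y.grad)   # y.grad = x*x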