
retain grad hybrid
multiple_output

Intermediate_container

retain_grad_hybrid
KexinFeng committed Jul 7, 2022
1 parent fa9733d commit 6efb2b7
Showing 15 changed files with 1,075 additions and 1,003 deletions.
9 changes: 9 additions & 0 deletions include/mxnet/c_api.h
@@ -1283,6 +1283,15 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXAutogradDropGrads(uint32_t num_var, NDArrayHandle* var_handles);
/*!
 * \brief mark nonleaf NDArrays as variables during deferred computation
 * \param nleaf_handles handles of the nonleaf NDArrays to mark
 * \param num_nleafs number of nonleaf NDArrays
* \param cnt_var count of existing marked nonleaf variables
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle *nleaf_handles,
int num_nleafs,
int cnt_var);
/*!
* \brief unmark nonleaf NDArrays to free the memory
* \param num_var number of variable NDArrays
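
A minimal sketch of how this new C-API entry point can be reached from Python through ctypes, mirroring the call that ``HybridBlock.intermediate()`` makes in python/mxnet/gluon/block.py below; the array names are illustrative, and the call is only meaningful while deferred compute is recording (e.g. inside a hybridized forward):

import ctypes
import mxnet as mx
from mxnet.base import _LIB, check_call

# Two arrays standing in for non-leaf outputs produced under deferred compute.
a = mx.np.ones((2, 3))
b = mx.np.zeros((2, 3))

# Pack the NDArray handles into a C array and mark them as DC variables;
# the last argument is the count of variables already marked.
handles = (ctypes.c_void_p * 2)(a.handle, b.handle)
check_call(_LIB.MXNDArrayMarkDCVariables(handles, 2, 0))
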
2 changes: 2 additions & 0 deletions include/mxnet/imperative.h
@@ -290,6 +290,8 @@ class Imperative {
void MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients);
/*! \brief mark nonleaf variables during DC for computing gradients. */
void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
/*! \brief unmark nonleaf variables to free the memory. */
void DropGrads(const std::vector<NDArray*>& variables);
/*! \brief compute the gradient of outputs w.r.t variables. */
2 changes: 2 additions & 0 deletions include/mxnet/ndarray.h
@@ -351,6 +351,8 @@ class NDArray {
bool fresh_out_grad() const;
/*! \return updated grad state in autograd_entry_ */
void set_fresh_out_grad(bool state) const;
/*! \brief copy the autograd_entry_ from src NDArray */
void copy_autograd_entry_(const NDArray* src);
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Throws an exception if the indices array shape is inconsistent
* Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
6 changes: 5 additions & 1 deletion python/mxnet/_ctypes/cached_op.py
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
if not default_device:
default_device = kwargs.pop('default_ctx', None)
out = kwargs.pop('out', None)
nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
if kwargs:
raise TypeError(
"CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@
*args,
type_id,
device_id,
*out_arg
len(out_arg),
*out_arg,
len(nleaf_vars),
*nleaf_vars
)
if out is not None:
return out
49 changes: 47 additions & 2 deletions python/mxnet/gluon/block.py
@@ -33,7 +33,8 @@
import json
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
_as_list
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
profiler as _profiler, device as _device
from ..symbol.numpy import _symbol as np_symbol
@@ -1091,6 +1092,7 @@ def __init__(self):
self._backend_opts = {}
self._partition_if_dynamic = True
self._first_forward = True
self._nleaf_vars = OrderedDict()

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
args_without_none = [ele for ele in args if ele is not None]
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, name, i in self._cached_op_args]
out = self._cached_op(*cargs)
out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)
@@ -1678,6 +1680,49 @@ def reset_ctx(self, ctx):
self.reset_device(ctx)


    def intermediate(self, names, var_arrays_inp, grad_req='write'):
        """Mark intermediate variables of the block so they can be retrieved later.

        Parameters
        ----------
        names : str or tuple[str]
            Name(s) under which the intermediate variable(s) are registered.
        var_arrays_inp : NDArray or tuple[NDArray]
            The output of the expression to be marked.
        grad_req : str, default 'write'
            Gradient request used when gradient buffers are attached (see ``Intermediate``).

        Returns
        -------
        NDArray or tuple[NDArray]
            The input ``var_arrays_inp``, unchanged.
        """
        if not self._active:
            var_arrays = _as_list(var_arrays_inp)
            names = _as_list(names)
            self._nleaf_vars.update(
                {name: Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
        else:
            prev_val = dc.set_deferred_compute(False)
            var_arrays = _as_list(var_arrays_inp)
            names = _as_list(names)
            # Prepare the ctypes array type and convert the NDArray handles.
            import ctypes
            var_handles_type = ctypes.c_void_p * len(var_arrays)
            var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
            check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
            self._nleaf_vars.update(
                {name: Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
            dc.set_deferred_compute(prev_val)
        return var_arrays_inp

    def attach_grad_intermediate(self):
        """Attach gradient buffers to all the marked intermediate variables,
        using each variable's ``grad_req``.
        """
        for val in self._nleaf_vars.values():
            val.data().attach_grad(grad_req=val.grad_req)

    def get_intermediate(self, names):
        """Get the marked intermediate variable(s) by name.

        ``names`` may be a single name or a list of names.
        """
        if isinstance(names, list):
            return [self._nleaf_vars[n] for n in names]
        else:
            return self._nleaf_vars[names]
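
Taken together, these methods let a hybridized block retain gradients of intermediate activations. A usage sketch under assumptions (the toy network, layer sizes, the name 'hidden', and the call order are illustrative, following the pattern of the code above: mark inside forward, attach grads once the cached graph exists, then record and backward):

import mxnet as mx
from mxnet.gluon import nn

class Net(nn.HybridBlock):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(4)
        self.fc2 = nn.Dense(2)

    def forward(self, x):
        h = self.fc1(x)
        # Register the activation under the name 'hidden'.
        h = self.intermediate('hidden', h)
        return self.fc2(h)

net = Net()
net.initialize()
net.hybridize()
x = mx.np.ones((1, 8))

net(x)                          # first forward builds the graph and marks 'hidden'
net.attach_grad_intermediate()  # allocate grad buffers per each variable's grad_req

with mx.autograd.record():
    out = net(x)
out.backward()

print(net.get_intermediate('hidden').data().grad)
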

class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
as feature extractors. For example, you may want to extract the output
37 changes: 37 additions & 0 deletions python/mxnet/gluon/parameter.py
@@ -773,3 +773,40 @@ def grad_req(self, req):
warnings.warn('Constant parameter "{}" does not support '
'grad_req other than "null", and new value "{}" '
'is ignored.'.format(self.name, req))

class Intermediate:
    """A container holding a marked intermediate variable of a Block.

    Parameters
    ----------
    name : str
        Name of this intermediate variable, used to retrieve it from the block.
    data : NDArray, optional
        The marked intermediate array.
    grad_req : {'write', 'add', 'null'}, default 'write'
        Specifies how to update the gradient array.

        - ``'write'`` means the gradient is written to the grad :py:class:`NDArray`
          every time.
        - ``'add'`` means the gradient is added to the grad :py:class:`NDArray`
          every time. You need to manually call ``zero_grad()`` to clear the
          gradient buffer before each iteration when using this option.
        - ``'null'`` means gradient is not requested for this variable; gradient
          arrays will not be allocated.
    """
    def __init__(self, name, data=None, grad_req='write'):
        self._name = name
        self._data = data
        self._grad_req = grad_req

    def __repr__(self):
        s = 'Intermediate name={name}'
        return s.format(name=self._name)

    def data(self):
        return self._data

    @property
    def name(self):
        return self._name

    @property
    def grad_req(self):
        return self._grad_req
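
For reference, a small standalone sketch of the container itself (hypothetical usage; in practice instances are created by ``HybridBlock.intermediate()``, and the import path assumes the class lives in ``mxnet.gluon.parameter`` as added in this file):

import mxnet as mx
from mxnet.gluon.parameter import Intermediate

arr = mx.np.zeros((2, 2))
inter = Intermediate('hidden', data=arr, grad_req='add')

print(inter)                       # Intermediate name=hidden
print(inter.name, inter.grad_req)  # hidden add
assert inter.data() is arr         # data() returns the marked array itself
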
43 changes: 26 additions & 17 deletions src/api/cached_op_api.cc
@@ -44,19 +44,21 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
}

std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(op->num_outputs());
if (args[num_inputs + 4].type_code() == kNull) {
for (int i = 0; i < op->num_outputs(); ++i)
ndoutputs.push_back(new NDArray());
} else {
int array_size = args_size - num_inputs - 4;
CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
<< " outputs, but " << array_size << " was given.";
for (int i = num_inputs + 4; i < array_size; ++i) {
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
}
}
int num_outputs = args[num_inputs + 4];
int num_nleafs = args[num_inputs + num_outputs + 5];
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(op->num_outputs());
if (args[num_inputs + 5].type_code() == kNull) {
for (int i = 0; i < op->num_outputs(); ++i) ndoutputs.push_back(new NDArray());
} else {
int array_size = args_size - num_inputs - num_nleafs - 6;
CHECK_EQ(array_size, op->num_outputs())
<< "CachedOp expects " << op->num_outputs() << " outputs, but "
<< array_size << " was given.";
for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) {
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
}
}

int default_dev_type;
int default_dev_id;
@@ -69,10 +71,17 @@
default_dev_id = ctx.dev_id;
}

// construct default context
Context ctx =
Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);
op->Forward(op_shared, ndinputs, ndoutputs, ctx);
std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
nleafs.push_back(static_cast<mxnet::NDArray*>(args[i + num_inputs + num_outputs + 6]));
}
op->set_nleafs(nleafs);

// construct default context
Context ctx = Context::Create(static_cast<Context::DeviceType>(default_dev_type),
default_dev_id);
op->Forward(op_shared, ndinputs, ndoutputs, ctx);

if (op->num_outputs() == 1) {
*ret = ndoutputs[0];
12 changes: 12 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
*out = s;
API_END_HANDLE_ERROR(delete s;);
}

int MXNDArrayMarkDCVariables(NDArrayHandle *nleaf_handles, int num_nleafs, int cnt_var) {
API_BEGIN();
std::vector<NDArray *> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
NDArray *array = reinterpret_cast<NDArray *>(nleaf_handles[i]);
nleafs.emplace_back(array);
}
Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
API_END();
}