This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

multiple_output
Intermediate_container

retain_grad_hybrid
KexinFeng committed Jul 6, 2022
1 parent cf153c0 commit 4b28fc8
Showing 15 changed files with 1,302 additions and 1,258 deletions.
9 changes: 9 additions & 0 deletions include/mxnet/c_api.h
@@ -1274,6 +1274,15 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
NDArrayHandle *var_handles,
uint32_t *reqs_array,
NDArrayHandle *grad_handles);
/*!
 * \brief mark nonleaf NDArrays as variables during deferred computation
 * \param nleaf_handles handles of the nonleaf NDArrays to be marked
 * \param num_nleafs number of nonleaf NDArrays
 * \param cnt_var count of existing marked nonleaf variables
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle *nleaf_handles,
int num_nleafs,
int cnt_var);
/*!
* \brief unmark nonleaf NDArrays to free the memory
* \param num_var number of variable NDArrays
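
The new entry point takes a flat array of NDArray handles plus two counts. A minimal ctypes sketch of how it can be driven from Python, mirroring the block.py change further down; the helper name mark_dc_variables is hypothetical:

import ctypes
from mxnet.base import _LIB, check_call

def mark_dc_variables(nleaf_arrays, cnt_var=0):
    # Pack the NDArray handles into a C array of void pointers and let the
    # backend mark them as nonleaf variables during deferred computation.
    handles_type = ctypes.c_void_p * len(nleaf_arrays)
    handles = handles_type(*[arr.handle for arr in nleaf_arrays])
    check_call(_LIB.MXNDArrayMarkDCVariables(handles, len(nleaf_arrays), cnt_var))
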
2 changes: 2 additions & 0 deletions include/mxnet/imperative.h
@@ -272,6 +272,8 @@ class Imperative {
void MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients);
/*! \brief mark nonleaf variables during DC for computing gradients. */
void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
/*! \brief unmark nonleaf variables to free the memory. */
void DropGrads(const std::vector<NDArray*>& variables);
/*! \brief compute the gradient of outputs w.r.t variables. */
2 changes: 2 additions & 0 deletions include/mxnet/ndarray.h
@@ -351,6 +351,8 @@ class NDArray {
bool fresh_out_grad() const;
/*! \return updated grad state in autograd_entry_ */
void set_fresh_out_grad(bool state) const;
/*! \brief copy the autograd_entry_ from src NDArray */
void copy_autograd_entry_(const NDArray* src);
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Throws an exception if the indices array shape is inconsistent
* Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
6 changes: 5 additions & 1 deletion python/mxnet/_ctypes/cached_op.py
@@ -75,6 +75,7 @@ def __call__(self, *args, **kwargs):
# New FFI only supports numpy ndarray
default_ctx = kwargs.pop('default_ctx', None)
out = kwargs.pop('out', None)
nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
if kwargs:
raise TypeError(
"CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -91,7 +92,10 @@ def __call__(self, *args, **kwargs):
*args,
type_id,
device_id,
*out_arg
len(out_arg),
*out_arg,
len(nleaf_vars),
*nleaf_vars
)
if out is not None:
return out
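
The __call__ change above extends the variadic FFI invocation with explicit counts so the backend can split the flat argument list into outputs and marked nonleaf variables. A hedged sketch of how the new keyword is supplied by a caller; the wrapper below is hypothetical and simply mirrors HybridBlock._call_cached_op from the next file:

def call_cached_op_with_intermediates(block, *cargs):
    # Hypothetical helper: the Intermediate containers kept by the block are
    # passed via the new '_nleaf_vars' keyword; CachedOp.__call__ unwraps them
    # to NDArrays and appends 'len(nleaf_vars), *nleaf_vars' after
    # 'len(out_arg), *out_arg' in the variadic FFI call.
    return block._cached_op(*cargs, _nleaf_vars=block._nleaf_vars.values())
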
51 changes: 48 additions & 3 deletions python/mxnet/gluon/block.py
@@ -31,13 +31,14 @@
import json
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
_as_list
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
profiler as _profiler, context as _context
from ..symbol.numpy import _symbol as np_symbol
from ..symbol import Symbol, fromjson
from ..ndarray import NDArray
from .parameter import Parameter, DeferredInitializationError
from .parameter import Parameter, DeferredInitializationError, Intermediate
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
from .. import numpy_extension as _mx_npx
@@ -1054,6 +1055,7 @@ def __init__(self):
self._backend_opts = {}
self._partition_if_dynamic = True
self._first_forward = True
self._nleaf_vars = OrderedDict()

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -1264,7 +1266,7 @@ def _call_cached_op(self, *args):
args_without_none = [ele for ele in args if ele is not None]
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, name, i in self._cached_op_args]
out = self._cached_op(*cargs)
out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)
@@ -1635,6 +1637,49 @@ def reset_ctx(self, ctx):
for p in params.values():
p.reset_ctx(ctx)

def intermediate(self, names, var_arrays_inp, grad_req='write'):
"""Mark the intermediate variables.
Parameters
----------
name : str or tuple[str], name of the registered intermediate variable
var_arrays_inp : ndarray or tuple[ndarray], the output of the expression
grad_req : str, gradient request
"""
if not self._active:
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
else:
prev_val = dc.set_deferred_compute(False)
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
# Prepare ctypes array types
import ctypes
var_handles_type = ctypes.c_void_p * len(var_arrays)
# Convert handles
var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
dc.set_deferred_compute(prev_val)
return var_arrays_inp

def attach_grad_intermediate(self):
"""Attach gradient to all the intermediate variables.
"""
for val in self._nleaf_vars.values():
val.data().attach_grad(grad_req=val.grad_req)

def get_intermediate(self, names):
"""Get the intermediate variables by names
"""
if isinstance(names, list):
return [self._nleaf_vars[n] for n in names]
else:
return self._nleaf_vars[names]

class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
as feature extractors. For example, you may want to extract the output
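
Taken together, the new HybridBlock methods give the following workflow. A minimal usage sketch under stated assumptions: the Net class, its layer sizes and the input shape are made up for illustration; only intermediate(), attach_grad_intermediate() and get_intermediate() come from the change above, and whether the gradient actually reaches the marked buffer in the hybridized (cached-op) path is exactly what this commit implements:

import mxnet as mx
from mxnet import autograd, np, npx
from mxnet.gluon import nn

npx.set_np()

class Net(nn.HybridBlock):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(4)
        self.fc2 = nn.Dense(2)

    def forward(self, x):
        h = self.fc1(x)
        # Register the hidden activation under the name 'hidden'.
        h = self.intermediate('hidden', h)
        return self.fc2(h)

net = Net()
net.initialize()
net.hybridize()

x = np.random.uniform(size=(3, 8))
net(x)                            # first forward pass records the marked variable
net.attach_grad_intermediate()    # allocate grad buffers for all marked variables

with autograd.record():
    out = net(x)
out.backward()

hidden = net.get_intermediate('hidden')
print(hidden.data().grad)         # gradient w.r.t. the marked intermediate
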
37 changes: 37 additions & 0 deletions python/mxnet/gluon/parameter.py
@@ -760,3 +760,40 @@ def grad_req(self, req):
warnings.warn('Constant parameter "{}" does not support '
'grad_req other than "null", and new value "{}" '
'is ignored.'.format(self.name, req))

class Intermediate:
"""A Container holding marked intermediate variables of Blocks.
Parameters
----------
name : str.
Name of this parameter. It be used to retrieve the marked variables.
grad_req : {'write', 'add', 'null'}, default 'write'
Specifies how to update gradient to grad arrays.
- ``'write'`` means everytime gradient is written to grad :py:class:`NDArray`.
- ``'add'`` means everytime gradient is added to the grad :py:class:`NDArray`. You need
to manually call ``zero_grad()`` to clear the gradient buffer before each
iteration when using this option.
- 'null' means gradient is not requested for this parameter. gradient arrays
will not be allocated.
"""
def __init__(self, name, data=None, grad_req='write'):
self._name = name
self._data = data
self._grad_req = grad_req

def __repr__(self):
s = 'Intermediate name={name}'
return s.format(name=self._name)

def data(self):
return self._data

@property
def name(self):
return self._name

@property
def grad_req(self):
return self._grad_req
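
The container itself is deliberately small. A short illustrative sketch of what it exposes; the array and name below are made up:

from mxnet import np, npx
from mxnet.gluon.parameter import Intermediate

npx.set_np()

# 'add' accumulates gradients across backward passes until zero_grad() is called.
inter = Intermediate('hidden', data=np.zeros((3, 4)), grad_req='add')
print(inter)           # Intermediate name=hidden
print(inter.name)      # 'hidden'
print(inter.grad_req)  # 'add'
arr = inter.data()     # the wrapped ndarray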