This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add better partial args/aux handling in symbol optimize_for #18350

Merged
merged 4 commits on Jul 14, 2020
30 changes: 30 additions & 0 deletions include/mxnet/c_api.h
@@ -2235,6 +2235,25 @@ MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle
* \param num_options number of key value pairs
* \param keys keys for options
* \param vals values corresponding to keys
* \param num_input_shapes number of input shapes
* \param input_shape_names names of the input shapes
 * \param input_shape_data pointer to the contiguous shape data
 * \param input_shape_idx array of per-shape starting indices; the shape length for the i-th input
 *                        shape is calculated as input_shape_idx[i+1] - input_shape_idx[i]
* \param num_input_dtypes number of input data types
* \param input_dtype_names array of names of the input data types
* \param input_dtypes array of values of the input data types
 * \param num_input_stypes number of input storage types
* \param input_stype_names array of names of the input storage types
* \param input_stypes array of values of input storage types
* \param skip_infer if the optimization should skip the attribute inferences
* (to use if the backend does not require shape inference)
 * \param new_args_cnt pointer to a number to store the number of new args
 * \param new_args_handle pointer to an array to store the new args handles
 * \param new_arg_names_handle pointer to an array to store the new arg names
 * \param new_aux_cnt pointer to a number to store the number of new aux
 * \param new_aux_handle pointer to an array to store the new aux handles
 * \param new_aux_names_handle pointer to an array to store the new aux names
*/
MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const char* backend_name,
@@ -2247,6 +2266,17 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
Review comment (Member): Please add those new parameters to the docstring

int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
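
For reference, the flattened shape encoding documented above (a contiguous input_shape_data buffer plus per-shape starting offsets in input_shape_idx) can be sketched in a few lines of Python. The input names and shapes below are made-up examples; the packing mirrors what the Python frontend builds from shape_dict before calling into the C API.

# Minimal sketch of the flattened shape encoding expected by MXOptimizeForBackend.
# 'data' and 'weight' are hypothetical input names, not names required by the API.
shape_dict = {'data': (1, 3, 224, 224), 'weight': (64, 3, 3, 3)}

input_shape_names = []
input_shape_data = []   # contiguous shape values for all inputs
input_shape_idx = [0]   # offsets: shape i occupies [idx[i], idx[i+1])

for name, shape in shape_dict.items():
    input_shape_names.append(name)
    input_shape_data.extend(shape)
    input_shape_idx.append(len(input_shape_data))

print(input_shape_data)  # [1, 3, 224, 224, 64, 3, 3, 3]
print(input_shape_idx)   # [0, 4, 8]; length of shape i is idx[i+1] - idx[i]
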
98 changes: 92 additions & 6 deletions python/mxnet/symbol/symbol.py
@@ -1446,7 +1446,8 @@ def _gen_atomic_symbol(self):


# pylint: disable=too-many-locals
def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
def optimize_for(self, backend, args=None, aux=None, ctx=None,
shape_dict=None, type_dict=None, stype_dict=None, skip_infer=False, **kwargs):
"""Partitions current symbol and optimizes it for a given backend,
returns new partitioned symbol.

@@ -1457,19 +1458,33 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):

args : dict of str to NDArray, optional
Input arguments to the symbol, required to infer shapes/types before partitioning

- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.
              to the corresponding `NDArray`. `NDArray`s of undefined arguments don't have to be
              specified in the dict.

aux : dict of str to NDArray, optional
Input auxiliary arguments to the symbol

- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.

ctx : Context, optional
Device context, used to infer stypes

shape_dict : Dict of str->tuple, optional
Input shape dictionary.
            Used only if the corresponding input `NDArray` is not provided in `args`.

type_dict : Dict of str->numpy.dtype, optional
Input type dictionary.
            Used only if the corresponding input `NDArray` is not provided in `args`.

stype_dict : Dict of str->str, optional
Input storage type dictionary.
            Used only if the corresponding input `NDArray` is not provided in `args`.

skip_infer : bool, optional
If True, the optimization skips the shape, type and storage type inference pass.

kwargs : optional arguments
Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`

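A minimal usage sketch of the extended signature described above, assuming a partitioning backend registered under the placeholder name 'myBackend' (the symbol and input names are likewise illustrative): only 'a' is passed as an NDArray, while the shape, dtype and storage type of 'b' come from the new optional dicts.

import mxnet as mx
import numpy as np

# Hypothetical two-input symbol; 'myBackend' stands in for a registered
# partitioning backend and is not a backend shipped with MXNet.
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
sym = a + b

# Only 'a' is supplied as an NDArray; the shape, dtype and storage type of 'b'
# come from the optional dicts, which is the partial-args case this change enables.
part_sym = sym.optimize_for('myBackend',
                            args={'a': mx.nd.ones((3, 2))},
                            shape_dict={'b': (3, 2)},
                            type_dict={'b': np.float32},
                            stype_dict={'b': 'default'})

# If the backend does not need shape/type information, inference can be skipped.
part_sym_no_infer = sym.optimize_for('myBackend', skip_infer=True)
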
@@ -1488,18 +1503,78 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
args_handle = c_array(NDArrayHandle, [])
else:
args_handle, args_ = self._get_ndarray_inputs('args', args,
self.list_arguments(), False)
self.list_arguments(), True)

if aux is None or len(aux) == 0:
aux_ = []
aux_handle = c_array(NDArrayHandle, [])
else:
aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux,
self.list_auxiliary_states(), False)
self.list_auxiliary_states(), True)
if ctx is None:
ctx = current_context()
assert isinstance(ctx, Context)


# parse input data shape dict
num_input_shapes = 0
input_shape_names = ctypes.POINTER(ctypes.c_char_p)()
input_shape_data = ctypes.POINTER(mx_int64)()
input_shape_idx = ctypes.POINTER(mx_uint)()
if shape_dict is not None:
input_shape_names = []
input_shape_data = []
input_shape_idx = [0]
for k, v in shape_dict.items():
if isinstance(v, (tuple, list)):
input_shape_names.append(k)
input_shape_data.extend(v)
input_shape_idx.append(len(input_shape_data))
else:
raise ValueError(str(v) + " has to be a tuple or list.")
num_input_shapes = mx_uint(len(input_shape_names))
input_shape_names = c_str_array(input_shape_names)
input_shape_data = c_array_buf(mx_int64, array('q', input_shape_data))
input_shape_idx = c_array_buf(mx_uint, array('i', input_shape_idx))

# parse input data types dict
num_input_types = 0
input_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names
input_type_data = ctypes.POINTER(mx_uint)() # provided types
if type_dict is not None:
input_type_names = []
input_type_data = []
for k, v in type_dict.items():
v = _numpy.dtype(v).type
if v in _DTYPE_NP_TO_MX:
input_type_names.append(k)
input_type_data.append(_DTYPE_NP_TO_MX[v])
else:
raise ValueError(str(v) + " is not a MXNet type.")

num_input_types = mx_uint(len(input_type_names))
input_type_names = c_str_array(input_type_names)
input_type_data = c_array_buf(ctypes.c_int, array('i', input_type_data))

# parse input data storage types dict
num_input_stypes = 0
# provided storage type argument names
input_stype_names = ctypes.POINTER(ctypes.c_char_p)()
input_stype_data = ctypes.POINTER(mx_uint)() # provided storage types
if stype_dict is not None:
input_stype_names = []
input_stype_data = []
for k, v in stype_dict.items():
if v in _STORAGE_TYPE_STR_TO_ID:
input_stype_names.append(k)
input_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v])
else:
raise ValueError(str(v) + " is not a MXNet storage type.")

num_input_stypes = mx_uint(len(input_stype_names))
input_stype_names = c_str_array(input_stype_names)
input_stype_data = c_array_buf(ctypes.c_int, array('i', input_stype_data))

new_args_size = ctypes.c_uint()
new_arg_names = ctypes.POINTER(ctypes.c_char_p)()
new_args_handle = ctypes.POINTER(NDArrayHandle)()
@@ -1523,6 +1598,17 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
mx_uint(len(key_list)),
c_str_array(key_list),
c_str_array(val_list),
num_input_shapes,
input_shape_names,
input_shape_data,
input_shape_idx,
num_input_types,
input_type_names,
input_type_data,
num_input_stypes,
input_stype_names,
input_stype_data,
ctypes.c_bool(skip_infer),
ctypes.byref(new_args_size),
ctypes.byref(new_args_handle),
ctypes.byref(new_arg_names),
120 changes: 82 additions & 38 deletions src/c_api/c_api_symbolic.cc
@@ -1360,6 +1360,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
const mx_uint num_options,
const char** keys,
const char** vals,
const uint32_t num_input_shapes,
const char** input_shape_names,
const int64_t* input_shape_data,
const uint32_t* input_shape_idx,
const uint32_t num_input_dtypes,
const char** input_dtype_names,
const int* input_dtypes,
const uint32_t num_input_stypes,
const char** input_stype_names,
const int* input_stypes,
bool skip_infer,
int* new_args_cnt,
NDArrayHandle** new_args_handle,
char*** new_arg_names_handle,
@@ -1383,47 +1394,80 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
if (args_len || aux_len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);
size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
const auto &in_arg = *(in_aux_ptr[aux_top++]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for";
const auto &in_arg = *(in_args_ptr[args_top++]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
if (!skip_infer) {
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);

// create the input shape, dtype and stype maps
std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
for (uint32_t i = 0; i < num_input_shapes; ++i) {
input_shape_map.emplace(input_shape_names[i],
mxnet::TShape(input_shape_data + input_shape_idx[i],
input_shape_data + input_shape_idx[i+1]));
}
std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
for (uint32_t i = 0; i < num_input_dtypes; ++i) {
input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
}
std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
for (uint32_t i = 0; i < num_input_stypes; ++i) {
input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
}
}

g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
if (in_aux_ptr[aux_top] != nullptr) {
const auto &in_arg = *(in_aux_ptr[aux_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
aux_top++;
} else {
auto name = input_names[i];
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << name << "' in provided args to optimize_for";
if (in_args_ptr[args_top] != nullptr) {
const auto &in_arg = *(in_args_ptr[args_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
// input_names[i] is not in args but can be in the optional
// shape/type/stype attribute dicts.
auto it_shape = input_shape_map.find(name);
if (it_shape != input_shape_map.end()) {
arg_shapes[i] = it_shape->second;
}
auto it_type = input_dtype_map.find(name);
if (it_type != input_dtype_map.end()) {
arg_dtypes[i] = it_type->second;
}
it_type = input_stype_map.find(name);
if (it_type != input_stype_map.end()) {
arg_stypes[i] = it_type->second;
}
}
args_top++;
}
}

// infer shapes
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
common::HandleInferTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<nnvm::DTypeVector>("dtype"));
}
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<StorageTypeVector>("storage_type"));
g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));

// infer shapes
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
}
// set args/aux as attributes on graph so that subgraph property can use them
std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);