
Graph Partition API #15886

Merged
merged 24 commits into from Sep 3, 2019
Commits (24)
837edd6
API to trigger partitioning
mseth10 Aug 14, 2019
868599c
pre- and post-partition functions for subgraph property
mseth10 Aug 15, 2019
eb7c7c7
adding infer shape type before partition
mseth10 Aug 16, 2019
ea0514f
modifying pre/post-partition declaration
mseth10 Aug 16, 2019
1eddd9b
adding support for infer shape type before partition
mseth10 Aug 16, 2019
32142bf
passing kwargs down to pre/post partition functions
mseth10 Aug 16, 2019
c581560
move InferForwardAttrs to common/
mseth10 Aug 17, 2019
cf885d4
Addressing github comments
mseth10 Aug 19, 2019
03d1f09
refactoring to enable infer shape/type without storage type
mseth10 Aug 20, 2019
ff30b53
check if subgraph rejected by subgraph property
mseth10 Aug 20, 2019
8ec7276
adding description
mseth10 Aug 21, 2019
8172101
setting graph attribute context from args
mseth10 Aug 21, 2019
bd567a3
adding unit test for optimize_for with default backend
mseth10 Aug 21, 2019
7a095ad
fixing args access
mseth10 Aug 23, 2019
323e53b
removing options_map from PostPartition
mseth10 Aug 23, 2019
e5b68a9
addressing PR comment
mseth10 Aug 23, 2019
338f063
adding logs about status of subgraph node creation
mseth10 Aug 26, 2019
a5d409f
allowing partial infer shapes
mseth10 Aug 27, 2019
663210f
added context argument back to optimize_for and removed args context …
Aug 29, 2019
cb1dc28
fixed spacing and dev_type
Aug 29, 2019
8fba76c
fixing lint
mseth10 Aug 29, 2019
4f00222
reorganized args list to optimize_for
Aug 29, 2019
e8428ef
fixing spacing
mseth10 Aug 29, 2019
da8f6bf
dereferencing dev_type
mseth10 Aug 29, 2019
21 changes: 21 additions & 0 deletions include/mxnet/c_api.h
@@ -2002,6 +2002,27 @@ MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
* \param ret_sym_handle returned atomic symbol
*/
MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle);
/*!
 * \brief Partitions a symbol for the given backend, potentially creating subgraphs
 * \param sym_handle symbol to be partitioned
 * \param backend_name name of the backend, as registered in SubgraphBackendRegistry
 * \param dev_type context device type
 * \param ret_sym_handle partitioned symbol returned
 * \param len number of args
 * \param in_args_handle args array
 * \param num_options number of key/value pairs
 * \param keys keys for options
 * \param vals values corresponding to keys
 */
MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
const char* backend_name,
const int dev_type,
SymbolHandle* ret_sym_handle,
const mx_uint len,
NDArrayHandle* in_args_handle,
const mx_uint num_options,
const char** keys,
const char** vals);


//--------------------------------------------
58 changes: 58 additions & 0 deletions python/mxnet/symbol/symbol.py
@@ -1437,6 +1437,64 @@ def _gen_atomic_symbol(self):
return Symbol(handle)


def optimize_for(self, backend, args=None, ctx=None, **kwargs):
"""Partitions current symbol and optimizes it for a given backend,
returns new partitioned symbol.

Parameters
----------
backend : str
The name of backend, as registered in `SubgraphBackendRegistry`

args : list of NDArray or dict of str to NDArray, optional
Input arguments to the symbol, required to infer shapes/types before partitioning

- If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.

ctx : Context, optional
Device context, used to infer stypes

kwargs : optional arguments
Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`

Returns
-------
out : SymbolHandle
The created symbol for target backend.
"""
out = SymbolHandle()
assert isinstance(backend, str)

if args is None:
args = []
args_handle = c_array(NDArrayHandle, [])
else:
listed_arguments = self.list_arguments()
args_handle, args = self._get_ndarray_inputs('args', args, listed_arguments, False)

if ctx is None:
ctx = current_context()
assert isinstance(ctx, Context)

key_list = []
val_list = []
for key, val in kwargs.items():
key_list.append(key)
val_list.append(str(val))
check_call(_LIB.MXOptimizeForBackend(self.handle,
c_str(backend),
ctypes.c_int(ctx.device_typeid),
ctypes.byref(out),
mx_uint(len(args)),
args_handle,
mx_uint(len(key_list)),
c_str_array(key_list),
c_str_array(val_list)))
return Symbol(out)
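
A minimal usage sketch of the new API (the backend name 'example_backend' and the
option 'my_option' are hypothetical; a real backend must already be registered in
SubgraphBackendRegistry for optimize_for to find it):

import mxnet as mx

a = mx.sym.var('a', shape=(2, 2))
b = mx.sym.var('b', shape=(2, 2))
sym = mx.sym.exp(a + b)

# Partition without shape/type inference; the backend sees unknown shapes/types.
part1 = sym.optimize_for('example_backend')

# Partition with inference: pass concrete args and a context, plus options that are
# forwarded to the backend's PrePartition hook as string key/value pairs.
args = {'a': mx.nd.ones((2, 2)), 'b': mx.nd.ones((2, 2))}
part2 = sym.optimize_for('example_backend', args, ctx=mx.cpu(), my_option='on')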


# pylint: disable=too-many-locals
def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
group2ctx=None, shared_arg_names=None, shared_exec=None,
64 changes: 64 additions & 0 deletions src/c_api/c_api_symbolic.cc
@@ -30,6 +30,7 @@
#include "nnvm/pass_functions.h"
#include "nnvm/symbolic.h"
#include "./c_api_common.h"
#include "../common/exec_utils.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
#include "../operator/subgraph/subgraph_property.h"
@@ -1214,3 +1215,66 @@ int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) {
*out = out_sym;
API_END_HANDLE_ERROR(delete out_sym);
}

int MXOptimizeForBackend(SymbolHandle sym_handle,
const char* backend_name,
const int dev_type,
SymbolHandle* ret_sym_handle,
const mx_uint len,
NDArrayHandle* in_args_handle,
const mx_uint num_options,
const char** keys,
const char** vals) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
if (len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(len);
nnvm::DTypeVector arg_dtypes(len);
StorageTypeVector arg_stypes(len);
for (mx_uint i = 0; i < len; i++) {
const auto &in_arg = *(in_args_ptr[i]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
const auto& indexed_graph = g.indexed_graph();
const auto num_forward_inputs = indexed_graph.input_nodes().size();
g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
// infer shapes (partial shape inference is allowed, so unlike dtype/stype there is no unknown-shape check)
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
common::HandleInferTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<nnvm::DTypeVector>("dtype"));
}
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
g.GetAttr<StorageTypeVector>("storage_type"));
}
}
std::vector<std::pair<std::string, std::string>> options_map;
for (mx_uint i = 0; i < num_options; ++i) {
options_map.emplace_back(keys[i], vals[i]);
}
const auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto property : subgraph_prop_list) {
property->PrePartition(g, options_map);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
g = ApplyPass(std::move(g), "BuildSubgraph");
g.attrs.erase("subgraph_property");
property->PostPartition(g);
}
s->outputs = g.outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
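
The kwargs accepted by the Python optimize_for surface here as options_map. A hedged
sketch of the round trip ('example_backend' and the option names are hypothetical):
the binding stringifies every value, so PrePartition only ever sees string pairs.

import mxnet as mx

sym = mx.sym.exp(mx.sym.var('data', shape=(2, 2)))
# The backend must be registered in SubgraphBackendRegistry.
part = sym.optimize_for('example_backend', threshold=0.5, verbose=True)
# The binding applies str() to each value, so this backend's PrePartition hook
# receives options_map == [("threshold", "0.5"), ("verbose", "True")]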
47 changes: 27 additions & 20 deletions src/operator/subgraph/build_subgraph.cc
@@ -572,30 +572,37 @@ void CreateSubgraphNode(nnvm::Graph* g,
}
const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_selector, subgraph_id);

-  // Connect the external nodes to the subgraph node.
-  subg_prop->ConnectSubgraphOutputs(n, &output_entries);
-  subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
-
-  const auto& indexed_graph = g->indexed_graph();
-  for (size_t i = 0; i < n->inputs.size(); ++i) {
-    auto& e = n->inputs[i];
-    // update entry_top_order_map with newly created orig_input_entries
-    auto it = entry_top_order_map->find(input_entries[i]);
-    CHECK(it != entry_top_order_map->end());
-    entry_top_order_map->emplace(&e, it->second);
-    // update input entries' source simple nodes' outputs map
-    nnvm::Node* node = e.node.get();
-    if (indexed_graph.exist(node)) {
-      const auto nid = indexed_graph.node_id(node);
-      BiDirectedNode* sn = simple_nodes[nid].get();
-      for (BiDirectedNode* dest_node : subgraph_nodes) {
-        sn->outputs.erase(dest_node->node);
-      }
-      sn->outputs[n.get()].push_back(i);
-    }
-  }
+  // CreateSubgraphNode returns nullptr if the subgraph property determines that the subgraph is sub-optimal.
+  // In that case, the subgraph node is not created and the graph is left unmodified.
+  if (n) {
+    // Connect the external nodes to the subgraph node.
+    subg_prop->ConnectSubgraphOutputs(n, &output_entries);
+    subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
+
+    const auto& indexed_graph = g->indexed_graph();
+    for (size_t i = 0; i < n->inputs.size(); ++i) {
+      auto& e = n->inputs[i];
+      // update entry_top_order_map with newly created orig_input_entries
+      auto it = entry_top_order_map->find(input_entries[i]);
+      CHECK(it != entry_top_order_map->end());
+      entry_top_order_map->emplace(&e, it->second);
+      // update input entries' source simple nodes' outputs map
+      nnvm::Node* node = e.node.get();
+      if (indexed_graph.exist(node)) {
+        const auto nid = indexed_graph.node_id(node);
+        BiDirectedNode* sn = simple_nodes[nid].get();
+        for (BiDirectedNode* dest_node : subgraph_nodes) {
+          sn->outputs.erase(dest_node->node);
+        }
+        sn->outputs[n.get()].push_back(i);
+      }
+    }
+  }
 #if DEBUG_SUBGRAPH
+  if (n)
+    LOG(INFO) << "Subgraph node created and output_entries updated.";
+  else
+    LOG(INFO) << "Subgraph node not created, output_entries not updated.";
   PrintNodeEntries(output_entries);
 #endif
}
6 changes: 6 additions & 0 deletions src/operator/subgraph/subgraph_property.h
@@ -26,6 +26,7 @@
#include <unordered_map>
#include <vector>
#include <string>
#include <utility>

namespace mxnet {
namespace op {
@@ -221,6 +222,11 @@ class SubgraphProperty {
return nullptr;
}

/*!
 * \brief Hook invoked once per subgraph property before partitioning;
 * receives the option key/value pairs passed to optimize_for
 */
virtual void PrePartition(const nnvm::Graph& g,
const std::vector<std::pair<std::string, std::string>>& options_map) {}

/*!
 * \brief Hook invoked once per subgraph property after partitioning
 */
virtual void PostPartition(const nnvm::Graph& g) {}

virtual SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const {
auto v1_ptr = CreateSubgraphSelector();
return std::make_shared<SubgraphSelectorV2Bridge>(v1_ptr);
126 changes: 126 additions & 0 deletions tests/python/unittest/test_subgraph_op.py
@@ -146,11 +146,137 @@ def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
for i in range(len(outputs1)):
assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def set_random_inputs(exe1, input_names):
"""Sets random values to exe1's args and auxs"""
for name in input_names:
if name in exe1.arg_dict:
exe1.arg_dict[name][:] = mx.nd.random.uniform(shape=exe1.arg_dict[name].shape)
else:
assert name in exe1.aux_dict
exe1.aux_dict[name][:] = mx.nd.random.uniform(shape=exe1.aux_dict[name].shape)

def copy_inputs_between_executors(exe1, exe2, input_names):
"""Copies values of args and auxs from exe1 to exe2"""
for name in input_names:
if name in exe2.arg_dict:
exe2.arg_dict[name][:] = exe1.arg_dict[name]
else:
assert name in exe2.aux_dict
exe2.aux_dict[name][:] = exe1.aux_dict[name]

def _check_subgraph_exe5(sym, subgraph_backend, op_names):
"""Call optimize_for to trigger graph partitioning without infer shapes/types before,
then simple_bind and compare results of the partitioned sym and the original sym."""
# simple_bind
exe1 = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
input_names = sym.list_inputs()
set_random_inputs(exe1, input_names)
exe1.forward()

# partition before simple_bind
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend)
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))

exe2 = part_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
copy_inputs_between_executors(exe1, exe2, input_names)
exe2.forward()

# compare outputs
outputs1 = exe1.outputs
outputs2 = exe2.outputs
assert len(outputs1) == len(outputs2)
for i in range(len(outputs1)):
assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def _check_subgraph_exe6(sym, subgraph_backend, op_names):
"""Call optimize_for to trigger graph partitioning without infer shapes/types before,
then simple_bind and compare results of the partitioned sym and the original sym."""
# simple_bind
exe1 = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
input_names = sym.list_inputs()
set_random_inputs(exe1, input_names)
exe1.forward()

# infer shape/type as part of partitioning, before simple_bind
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend, exe1.arg_dict)
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))

exe2 = part_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
copy_inputs_between_executors(exe1, exe2, input_names)
exe2.forward()

# compare outputs
outputs1 = exe1.outputs
outputs2 = exe2.outputs
assert len(outputs1) == len(outputs2)
for i in range(len(outputs1)):
assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def _check_subgraph_exe7(sym, subgraph_backend, op_names):
"""Call optimize_for to trigger graph partitioning without infer shapes/types before,
then bind and compare results of the partitioned sym and the original sym."""
# bind
arg_shapes, _, aux_shapes = sym.infer_shape()
arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe1.forward()

# partition before bind
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend)
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))

exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe2.forward()

# compare outputs
outputs1 = exe1.outputs
outputs2 = exe2.outputs
assert len(outputs1) == len(outputs2)
for i in range(len(outputs1)):
assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def _check_subgraph_exe8(sym, subgraph_backend, op_names):
"""Call optimize_for to infer shapes, types and dtypes followed by graph partitioning,
then bind and compare results of the partitioned sym and the original sym."""
# bind
arg_shapes, _, aux_shapes = sym.infer_shape()
arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe1.forward()

# infer shape/type as part of partitioning, before bind
check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend, arg_array)
check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))

exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe2.forward()

# compare outputs
outputs1 = exe1.outputs
outputs2 = exe2.outputs
assert len(outputs1) == len(outputs2)
for i in range(len(outputs1)):
assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe(sym, subgraph_backend, op_names):
_check_subgraph_exe1(sym, subgraph_backend, op_names)
_check_subgraph_exe2(sym, subgraph_backend, op_names)
_check_subgraph_exe3(sym, subgraph_backend, op_names)
_check_subgraph_exe4(sym, subgraph_backend, op_names)
_check_subgraph_exe5(sym, subgraph_backend, op_names)
_check_subgraph_exe6(sym, subgraph_backend, op_names)
_check_subgraph_exe7(sym, subgraph_backend, op_names)
_check_subgraph_exe8(sym, subgraph_backend, op_names)
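
A minimal sketch of driving these checks directly, assuming the 'default' subgraph
backend (exercised elsewhere in this PR) and a toy symbol; op_names restricts which
operators the backend may absorb into subgraphs:

def test_optimize_for_sketch():
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    sym = mx.sym.exp(mx.sym.sin(data))
    check_subgraph_exe(sym, 'default', ['exp', 'sin'])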

def test_network_structure_1(subgraph_backend):
data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))