Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

More extensions fixes #19393

Merged
merged 7 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,25 +912,25 @@ class Registry {

/*! \brief declare a variable with custom name */
#define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _
#define MX_REGISTER_DEF_(Name) CustomOp MX_REGISTER_NAME_(Name)
#define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name)

#define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _
#define MX_REGISTER_PROP_DEF_(Name) CustomPartitioner MX_REGISTER_PROP_NAME_(Name)
#define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name)

#define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _
#define MX_REGISTER_PASS_DEF_(Name) CustomPass MX_REGISTER_PASS_NAME_(Name)
#define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name)

/*! \brief assign a var to a value */
#define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \
Registry<CustomOp>::get()->add(MX_TOSTRING(Name))
mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name))

#define REGISTER_PARTITIONER(Name) \
MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \
Registry<CustomPartitioner>::get()->add(MX_TOSTRING(Name))
mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name))

#define REGISTER_PASS(Name) \
MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
Registry<CustomPass>::get()->add(MX_TOSTRING(Name))
mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name))

/* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */

Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def _build_cache(self, *args):
'added to the parameter dicts.\n'
'Please check the backend.')

param = Parameter(name)
param = Parameter(name, dtype=param_data.dtype)
param._var_name = name
serialization_name = name # HybridBlock.export
param._load_init(param_data, args[0].context)
Expand Down
9 changes: 6 additions & 3 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,13 @@ def _reduce(self):
ctx = context.cpu()
if self._stype == 'default':
block = self.list_data()
if is_np_array():
data = sum([w.copyto(ctx) for w in block]) / len(block)
if len(block) > 1:
if is_np_array():
data = sum([w.copyto(ctx) for w in block]) / len(block)
else:
data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
else:
data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
data = self.data().copyto(ctx)
else:
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
Expand Down
26 changes: 21 additions & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,13 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
auto in_first = in_shape->begin();
auto in_last = in_first + in_shape->size() - extra_inputs;
mxnet::ShapeVector *sg_in_shapes = new mxnet::ShapeVector(in_first, in_last);
return mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, out_shape);
bool res = mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, out_shape);

// assign modified input shapes to ShapeVector
for (unsigned i = 0; i < sg_in_shapes->size(); ++i) {
SHAPE_ASSIGN_CHECK(*in_shape, i, sg_in_shapes->at(i));
}
return res;
};

// lambda function to call infer type
Expand Down Expand Up @@ -934,7 +940,12 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
auto in_last = in_first + in_type->size() - extra_inputs;
std::vector<int> *sg_in_types = new std::vector<int>(in_first, in_last);

return mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type);
bool res = mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type);
// copy and assign modified input types
for (size_t i = 0; i < sg_in_types->size(); i++) {
TYPE_ASSIGN_CHECK(*in_type, i, sg_in_types->at(i));
}
return res;
};

// lambda function to convert from external mutate_inputs to internal MXNet types
Expand Down Expand Up @@ -1034,8 +1045,13 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
auto in_last = in_first + in_stypes->size() - extra_inputs;
std::vector<int> *sg_in_stypes = new std::vector<int>(in_first, in_last);

return mxnet::op::DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
sg_in_stypes, out_stypes);
bool res = mxnet::op::DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
sg_in_stypes, out_stypes);
// copy and assign modified input storage types
for (size_t i = 0; i < sg_in_stypes->size(); i++) {
STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, sg_in_stypes->at(i));
}
return res;
};

// FGradient register lambda
Expand Down Expand Up @@ -1417,7 +1433,7 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize,
// this temp workspace holds memory allocated by custom library via OpResource
auto ndarray_alloc = [&](const mxnet::TShape &shape, Context ctx, int dtype,
std::string name, bool isArg) {
NDArray* arr = new NDArray(shape, ctx, dtype);
NDArray* arr = new NDArray(shape, ctx, false, dtype);
if (isArg) {
new_args.push_back(arr);
new_arg_names.push_back(name);
Expand Down
4 changes: 4 additions & 0 deletions tools/pip/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@
shutil.copytree(os.path.join(CURRENT_DIR, 'mxnet-build/3rdparty/tvm/nnvm/include/nnvm'),
os.path.join(CURRENT_DIR, 'mxnet/include/nnvm'))

# copy cc file for mxnet extensions
shutil.copy(os.path.join(CURRENT_DIR, 'mxnet-build/src/lib_api.cc'),
os.path.join(CURRENT_DIR, 'mxnet/src'))

package_name = 'mxnet'

variant = os.environ['mxnet_variant'].upper()
Expand Down