Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jul 7, 2022
1 parent 6efb2b7 commit 5d8f56c
Show file tree
Hide file tree
Showing 8 changed files with 1,052 additions and 963 deletions.
4 changes: 1 addition & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1289,9 +1289,7 @@ MXNET_DLL int MXAutogradDropGrads(uint32_t num_var, NDArrayHandle* var_handles);
* \param cnt_var count of existing marked nonleaf variables
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle *nleaf_handles,
int num_nleafs,
int cnt_var);
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var);
/*!
* \brief unmark nonleaf NDArrays to free the memory
* \param num_var number of variable NDArrays
Expand Down
50 changes: 25 additions & 25 deletions src/api/cached_op_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,21 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
}

int num_outputs = args[num_inputs + 4];
int num_nleafs = args[num_inputs + num_outputs + 5];
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(op->num_outputs());
if (args[num_inputs + 5].type_code() == kNull) {
for (int i = 0; i < op->num_outputs(); ++i) ndoutputs.push_back(new NDArray());
} else {
int array_size = args_size - num_inputs - num_nleafs - 6;
CHECK_EQ(array_size, op->num_outputs())
<< "CachedOp expects " << op->num_outputs() << " outputs, but "
<< array_size << " was given.";
for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) {
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
}
}
int num_outputs = args[num_inputs + 4];
int num_nleafs = args[num_inputs + num_outputs + 5];
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(op->num_outputs());
if (args[num_inputs + 5].type_code() == kNull) {
for (int i = 0; i < op->num_outputs(); ++i)
ndoutputs.push_back(new NDArray());
} else {
int array_size = args_size - num_inputs - num_nleafs - 6;
CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
<< " outputs, but " << array_size << " was given.";
for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) {
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
}
}

int default_dev_type;
int default_dev_id;
Expand All @@ -71,17 +71,17 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
default_dev_id = ctx.dev_id;
}

std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
nleafs.push_back(static_cast<mxnet::NDArray*>(args[i + num_inputs + num_outputs + 6]));
}
op->set_nleafs(nleafs);
std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
nleafs.push_back(static_cast<mxnet::NDArray*>(args[i + num_inputs + num_outputs + 6]));
}
op->set_nleafs(nleafs);

// construct default context
Context ctx = Context::Create(static_cast<Context::DeviceType>(default_dev_type),
default_dev_id);
op->Forward(op_shared, ndinputs, ndoutputs, ctx);
// construct default context
Context ctx =
Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);
op->Forward(op_shared, ndinputs, ndoutputs, ctx);

if (op->num_outputs() == 1) {
*ret = ndoutputs[0];
Expand Down
6 changes: 3 additions & 3 deletions src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,12 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
API_END_HANDLE_ERROR(delete s;);
}

int MXNDArrayMarkDCVariables(NDArrayHandle *nleaf_handles, int num_nleafs, int cnt_var) {
int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var) {
API_BEGIN();
std::vector<NDArray *> nleafs;
std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
NDArray *array = reinterpret_cast<NDArray *>(nleaf_handles[i]);
NDArray* array = reinterpret_cast<NDArray*>(nleaf_handles[i]);
nleafs.emplace_back(array);
}
Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
Expand Down
Loading

0 comments on commit 5d8f56c

Please sign in to comment.