Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix the bug of MXEnginePushAsyncND and MXEnginePushSyncND #15751

Merged
merged 11 commits into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2921,12 +2921,12 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
* \param wait Whether this is a WaitForVar operation.
*/
MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
bool wait DEFAULT(false));
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
bool wait DEFAULT(false));

/*!
* \brief Push a synchronous operation to the engine.
Expand All @@ -2944,11 +2944,11 @@ MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
* \param opr_name The operation name.
*/
MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));

#ifdef __cplusplus
}
Expand Down
40 changes: 20 additions & 20 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1544,18 +1544,18 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
}

int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reinterpret_cast is necessary for the cast from void** to NDArray**.

NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
Expand All @@ -1564,18 +1564,18 @@ int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
}

int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
Expand Down
117 changes: 74 additions & 43 deletions tests/cpp/engine/threaded_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,49 +257,80 @@ TEST(Engine, PushFunc) {

TEST(Engine, PushFuncND) {
auto ctx = mxnet::Context{};
mxnet::NDArray nd(ctx);

// Test #1
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
int* b = new int(101);
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);
std::vector<mxnet::NDArray*> nds;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, I'm thinking if it's more appropriate to use

std::vector<mxnet::NDArray> nds;

Do we have to use std::vector<mxnet::NDArray*>?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to use std::vector<mxnet::NDArray*>, because the type of the argument of the two APIs is an array pointer of NDArray*. Besides, it avoids the potential copy of NDArray.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an argument Context in the constructor of NDArray. I do not know how to use std::vector<mxnet::NDArray>.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we have to use std::vector<NDArray*> if we keep the interface to be NDArrayHandle *. However, I'm thinking whether we could directly use the std::vector<mxnet::NDArray> vec; and use nds.emplace_back(mxnet::NDArray(ctx)) or nds.push_back(std::move(temp_arr)) to fill the vector. In that case, the existing API will not be changed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds good, but I am worry that the type of arguments (const_vars_handle and mutable_vars_handle) of the two APIs MXEnginePushAsync and MXEnginePushSync is EngineVarHandle, namely void*. It casts void* to VarHandle*, namely Var** in https://github.com/apache/incubator-mxnet/blob/master/src/c_api/c_api.cc#L1475. Therefore, I don't know how to decide the type of const_nds_handle and mutable_nds_handle in MXEnginePushAsyncND and MXEnginePushSyncND.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep the consistency, const_var_handle and mutable_var_handle is the pointer of an array of VarHandle, and const_nds_handle and mutable_var_handle is the pointer of an array of NDArrayHandle.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now understand the logic here. To make the API consistent, I think we should also change the interface of MXEnginePushAsync and MXEnginePushSync. We should be safe to replace EngineVarHandle with VarHandle*. Am I right here ? @apeforest @yuxihu

const int num_nds = 5;
for (int i = 0; i < num_nds; ++i) {
mxnet::NDArray *pnd = new mxnet::NDArray(ctx);
nds.push_back(pnd);
}
for (int num_const_nds = 0; num_const_nds <= num_nds; ++num_const_nds) {
int num_mutable_nds = num_nds - num_const_nds;
void** const_nds_handle = num_const_nds > 0 ?
reinterpret_cast<void**>(nds.data()) : nullptr;
void** mutable_nds_handle = num_mutable_nds > 0 ?
reinterpret_cast<void**>(nds.data() + num_const_nds) : nullptr;

// Test #1
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, -1,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
int* b = new int(101);
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, -1,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, -1);
EXPECT_EQ(res, -1);
}
for (mxnet::NDArray* pnd : nds) {
delete pnd;
}
}

TEST(Engine, basics) {
Expand Down