Skip to content

Commit

Permalink
Fix the bug of MXEnginePushAsyncND and MXEnginePushSyncND (apache…
Browse files Browse the repository at this point in the history
…#15751)

* fix push sync nd api

* align code

* update test for syncnd

* fix bug in tests/cpp/engine/threaded_engine_test

* add more testcases for MXEnginePushSyncND and MXEnginePushAsyncND

* fix test

* fix

* fix

* lint

* ci

* retrigger CI
  • Loading branch information
wkcn committed Aug 8, 2019
1 parent 0a3413f commit cd42659
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 74 deletions.
22 changes: 11 additions & 11 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2863,12 +2863,12 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
* \param wait Whether this is a WaitForVar operation.
*/
MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
bool wait DEFAULT(false));
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
bool wait DEFAULT(false));

/*!
* \brief Push a synchronous operation to the engine.
Expand All @@ -2886,11 +2886,11 @@ MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
* \param opr_name The operation name.
*/
MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));

#ifdef __cplusplus
}
Expand Down
40 changes: 20 additions & 20 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1535,18 +1535,18 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
}

int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
Expand All @@ -1555,18 +1555,18 @@ int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
}

int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
Expand Down
117 changes: 74 additions & 43 deletions tests/cpp/engine/threaded_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,49 +257,80 @@ TEST(Engine, PushFunc) {

TEST(Engine, PushFuncND) {
auto ctx = mxnet::Context{};
mxnet::NDArray nd(ctx);

// Test #1
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
int* b = new int(101);
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);
std::vector<mxnet::NDArray*> nds;
const int num_nds = 5;
for (int i = 0; i < num_nds; ++i) {
mxnet::NDArray *pnd = new mxnet::NDArray(ctx);
nds.push_back(pnd);
}
for (int num_const_nds = 0; num_const_nds <= num_nds; ++num_const_nds) {
int num_mutable_nds = num_nds - num_const_nds;
void** const_nds_handle = num_const_nds > 0 ?
reinterpret_cast<void**>(nds.data()) : nullptr;
void** mutable_nds_handle = num_mutable_nds > 0 ?
reinterpret_cast<void**>(nds.data() + num_const_nds) : nullptr;

// Test #1
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, -1,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
int* b = new int(101);
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, -1,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, -1);
EXPECT_EQ(res, -1);
}
for (mxnet::NDArray* pnd : nds) {
delete pnd;
}
}

TEST(Engine, basics) {
Expand Down

0 comments on commit cd42659

Please sign in to comment.