-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Fix the bug of MXEnginePushAsyncND
and MXEnginePushSyncND
#15751
Changes from all commits
3e8c4b5
8ec22bf
87db11d
bdc47ad
753dc2b
290f324
d34a58e
85afeaf
15c8b26
14d4eb9
a7ae4ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -257,49 +257,80 @@ TEST(Engine, PushFunc) { | |
|
||
TEST(Engine, PushFuncND) { | ||
auto ctx = mxnet::Context{}; | ||
mxnet::NDArray nd(ctx); | ||
|
||
// Test #1 | ||
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter ====="; | ||
int* a = new int(100); | ||
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #2 | ||
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter ====="; | ||
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #3 | ||
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds ====="; | ||
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0); | ||
EXPECT_EQ(res, -1); | ||
|
||
// Test #4 | ||
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds ====="; | ||
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1); | ||
EXPECT_EQ(res, -1); | ||
|
||
// Test #5 | ||
LOG(INFO) << "===== Test #5: PushSyncND param and deleter ====="; | ||
int* b = new int(101); | ||
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #6 | ||
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter ====="; | ||
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #7 | ||
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds ====="; | ||
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0); | ||
EXPECT_EQ(res, -1); | ||
|
||
// Test #8 | ||
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds ====="; | ||
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1); | ||
EXPECT_EQ(res, -1); | ||
std::vector<mxnet::NDArray*> nds; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, I'm thinking if it's more appropriate to use std::vector<mxnet::NDArray> nds; Do we have to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is an argument Context in the constructor of NDArray. I do not know how to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we have to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It sounds good, but I am worry that the type of arguments ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To keep the consistency, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I now understand the logic here. To make the API consistent, I think we should also change the interface of |
||
const int num_nds = 5; | ||
for (int i = 0; i < num_nds; ++i) { | ||
mxnet::NDArray *pnd = new mxnet::NDArray(ctx); | ||
nds.push_back(pnd); | ||
} | ||
for (int num_const_nds = 0; num_const_nds <= num_nds; ++num_const_nds) { | ||
int num_mutable_nds = num_nds - num_const_nds; | ||
void** const_nds_handle = num_const_nds > 0 ? | ||
reinterpret_cast<void**>(nds.data()) : nullptr; | ||
void** mutable_nds_handle = num_mutable_nds > 0 ? | ||
reinterpret_cast<void**>(nds.data() + num_const_nds) : nullptr; | ||
|
||
// Test #1 | ||
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter ====="; | ||
int* a = new int(100); | ||
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, | ||
const_nds_handle, num_const_nds, | ||
mutable_nds_handle, num_mutable_nds); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #2 | ||
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter ====="; | ||
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, | ||
const_nds_handle, num_const_nds, | ||
mutable_nds_handle, num_mutable_nds); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #3 | ||
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds ====="; | ||
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, | ||
const_nds_handle, -1, | ||
mutable_nds_handle, num_mutable_nds); | ||
EXPECT_EQ(res, -1); | ||
|
||
// Test #4 | ||
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds ====="; | ||
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, | ||
const_nds_handle, num_const_nds, | ||
mutable_nds_handle, -1); | ||
EXPECT_EQ(res, -1); | ||
|
||
// Test #5 | ||
LOG(INFO) << "===== Test #5: PushSyncND param and deleter ====="; | ||
int* b = new int(101); | ||
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, | ||
const_nds_handle, num_const_nds, | ||
mutable_nds_handle, num_mutable_nds); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #6 | ||
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter ====="; | ||
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, | ||
const_nds_handle, num_const_nds, | ||
mutable_nds_handle, num_mutable_nds); | ||
EXPECT_EQ(res, 0); | ||
|
||
// Test #7 | ||
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds ====="; | ||
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, | ||
const_nds_handle, -1, | ||
mutable_nds_handle, num_mutable_nds); | ||
EXPECT_EQ(res, -1); | ||
|
||
// Test #8 | ||
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds ====="; | ||
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, | ||
const_nds_handle, num_const_nds, | ||
mutable_nds_handle, -1); | ||
EXPECT_EQ(res, -1); | ||
} | ||
for (mxnet::NDArray* pnd : nds) { | ||
delete pnd; | ||
} | ||
} | ||
|
||
TEST(Engine, basics) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reinterpret_cast
is necessary for the cast fromvoid**
toNDArray**
.