Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Mar 4, 2020
1 parent ffdfff0 commit d66ef96
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 21 deletions.
3 changes: 1 addition & 2 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2740,8 +2740,7 @@ MXNET_DLL int MXDatasetGetLen(DatasetHandle handle,
MXNET_DLL int MXDatasetGetItems(DatasetHandle handle,
uint64_t index,
int* num_outputs,
NDArrayHandle **outputs,
NDArrayHandle *is_scalar);
NDArrayHandle **outputs);

/*!
* \brief List all the available batchify function entries
Expand Down
7 changes: 2 additions & 5 deletions python/mxnet/gluon/data/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,15 @@ def __getitem__(self, idx):
create_ndarray_fn = _np_ndarray_cls if is_np_array() else _ndarray_cls
output_vars = ctypes.POINTER(NDArrayHandle)()
num_output = ctypes.c_int(0)
is_scalars = NDArrayHandle()
check_call(_LIB.MXDatasetGetItems(self.handle,
ctypes.c_uint64(idx),
ctypes.byref(num_output),
ctypes.byref(output_vars)))
out = [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
False) for i in range(num_output.value)]
nd_isscalar = create_ndarray_fn(is_scalars).asnumpy()
for i in range(num_output.value):
if nd_isscalar[i] == 1:
assert out[i].size == 1, "is_scalar size: {}".format(out[i].size)
out[i] = out[i].asnumpy()[0]
if out[i].size == 1:
out[i] = out[i].asnumpy()
if len(out) > 1:
return tuple(out)
return out[0]
Expand Down
5 changes: 3 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1716,11 +1716,12 @@ int MXDataIterNext(DataIterHandle handle, int *out) {
int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
bool no_label = db.data.size() < 2U;
NDArray* pndarray = new NDArray();
// temp hack to make label 1D
// TODO(tianjun) make label 1D when label_width=0
mxnet::TShape shape = db.data[1].shape();
if (shape.Size() < 1) {
mxnet::TShape shape = no_label? TShape({1,}) : db.data[1].shape();
if (no_label || shape.Size() < 1) {
// it's possible that label is not available and not required
// but we need to bypass the invalid copy
*pndarray = NDArray(TShape({1}), mxnet::Context::CPU(0));
Expand Down
17 changes: 10 additions & 7 deletions src/io/batchify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ class GroupBatchify : public BatchifyFunction {
CHECK_EQ(out_size, fs_.size()) << "In GroupBatchifyFunction, Elem size "
<< out_size << " and batchify function size " << fs_.size() << " must match";
outputs.resize(out_size);
std::vector<NDArray> tmp;
for (size_t i = 0; i < out_size; ++i) {
std::vector<std::vector<NDArray> > inp;
inp.reserve(inputs.size());
for (size_t j = 0; j < inputs.size(); ++j) {
std::vector<NDArray> curr({inputs[j][i]});
inp.emplace_back(curr);
}
std::vector<NDArray> tmp;
if (!fs_[i]->Batchify(inp, tmp)) return false;
outputs[i] = tmp[0];
}
Expand Down Expand Up @@ -149,13 +149,16 @@ class StackBatchify : public BatchifyFunction {
MSHADOW_TYPE_SWITCH_WITH_BOOL(dtype, DType, {
_Pragma("omp parallel for num_threads(bs)")
for (size_t j = 0; j < bs; ++j) {
// inputs[j][i].WaitToRead();
DType *ptr = outputs[i].data().dptr<DType>();
auto asize = ashape.Size();
RunContext rctx{outputs[i].ctx(), nullptr, nullptr, false};
auto dst = TBlob(ptr + asize * j, inputs[j][i].data().shape_, cpu::kDevMask, dtype, 0);
mxnet::ndarray::Copy<cpu, cpu>(inputs[j][i].data(), &dst, Context::CPU(), Context::CPU(), rctx);
omp_exc_.Run([&] {
// inputs[j][i].WaitToRead();
DType *ptr = outputs[i].data().dptr<DType>();
auto asize = ashape.Size();
RunContext rctx{outputs[i].ctx(), nullptr, nullptr, false};
auto dst = TBlob(ptr + asize * j, inputs[j][i].data().shape_, cpu::kDevMask, dtype, 0);
mxnet::ndarray::Copy<cpu, cpu>(inputs[j][i].data(), &dst, Context::CPU(), Context::CPU(), rctx);
});
}
omp_exc_.Rethrow();
})
}
return true;
Expand Down
8 changes: 5 additions & 3 deletions src/io/dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,20 +233,22 @@ class ImageRecordFileDataset : public Dataset {
size -= sizeof(header);
s += sizeof(header);
NDArray label = NDArray(Context::CPU(), mshadow::default_type_flag);
TBlob dst = label.data();
RunContext rctx{Context::CPU(), nullptr, nullptr, false};
if (header.flag > 0) {
auto label_shape = header.flag <= 1 ? TShape(0, 1) : TShape({header.flag});
label.ReshapeAndAlloc(label_shape);
TBlob dst = label.data();
mxnet::ndarray::Copy<cpu, cpu>(TBlob((void*)s, label.shape(), cpu::kDevMask, label.dtype(), 0),
&dst, Context::CPU(), Context::CPU(), rctx);
s += sizeof(float) * header.flag;
size -= sizeof(float) * header.flag;
} else {
// label is a scalar with ndim() == 0
label.ReshapeAndAlloc(TShape(0, 1));
mxnet::ndarray::Copy<cpu, cpu>(TBlob((void*)(&header.label), label.shape(), cpu::kDevMask, label.dtype(), 0),
&dst, Context::CPU(), Context::CPU(), rctx);
TBlob dst = label.data();
*(dst.dptr<float>()) = header.label;
// mxnet::ndarray::Copy<cpu, cpu>(TBlob((void*)(&header.label), label.shape(), cpu::kDevMask, label.dtype(), 0),
// &dst, Context::CPU(), Context::CPU(), rctx);
}
ret.resize(2);
ret[1] = label;
Expand Down
5 changes: 3 additions & 2 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,15 @@ def test_list_dataset():
pass


class Dataset(gluon.data.Dataset):
class _Dataset(gluon.data.Dataset):
    # Minimal in-memory dataset used as a multi-worker DataLoader fixture.
    # Renamed from `Dataset` with a leading underscore so test collectors /
    # other code don't mistake it for the public gluon.data.Dataset class.
    def __len__(self):
        # Fixed size: 100 items.
        return 100
    def __getitem__(self, key):
        # Each item is a (10,)-shaped NDArray filled with its own index,
        # so a batch's contents identify which sample it came from.
        return mx.nd.full((10,), key)

@with_seed()
def test_multi_worker():
data = Dataset()
data = _Dataset()
for thread_pool in [True, False]:
loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5, thread_pool=thread_pool)
for i, batch in enumerate(loader):
Expand Down Expand Up @@ -548,6 +548,7 @@ def test_batchify_group():
e = bf_handle([a, b, c])
assert d[0].shape == e[0].shape
assert d[1].shape == e[1].shape
print(d[0].asnumpy(), ',', e[0].asnumpy(), ',', e[1].asnumpy())
assert mx.test_utils.almost_equal(d[0].asnumpy(), e[0].asnumpy())
assert mx.test_utils.almost_equal(d[1].asnumpy(), e[1].asnumpy())
assert mx.test_utils.almost_equal(d[0].asnumpy(), np.stack((a[0], b[0], c[0])))
Expand Down

0 comments on commit d66ef96

Please sign in to comment.