Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed May 1, 2020
1 parent 6344393 commit 86b7c3f
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def forward(self, img, bbox):
if isinstance(self._fill, numeric_types):
dst = F.full(shape=(oh, ow, c), val=self._fill, dtype=img.dtype)
else:
fill = F.array(self._fill, dtype=img.dtype, ctx=img.context)
fill = F.array(self._fill, dtype=img.dtype, ctx=img.ctx)
if not c == fill.size:
raise ValueError("Channel and fill size mismatch, {} vs {}".format(c, fill.size))
dst = F.tile(fill.reshape((1, c)), reps=(oh * ow, 1)).reshape((oh, ow, c))
Expand Down
10 changes: 5 additions & 5 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,14 +656,14 @@ def imrotate(src, rotation_degrees, zoom_in=False, zoom_out=False):
# when a scalar is passed we wrap it into an array
if isinstance(rotation_degrees, Number):
rotation_degrees = nd.array([rotation_degrees] * len(src),
ctx=src.context)
ctx=src.ctx)

if len(src) != len(rotation_degrees):
raise ValueError(
"The number of images must be equal to the number of rotation angles"
)

rotation_degrees = rotation_degrees.as_in_context(src.context)
rotation_degrees = rotation_degrees.as_in_context(src.ctx)
rotation_rad = np.pi * rotation_degrees / 180
# reshape the rotations angle in order to be broadcasted
# over the `src` tensor
Expand All @@ -674,10 +674,10 @@ def imrotate(src, rotation_degrees, zoom_in=False, zoom_out=False):
hscale = (float(h - 1) / 2)
wscale = (float(w - 1) / 2)
h_matrix = (
nd.repeat(nd.arange(h, ctx=src.context).astype('float32').reshape(h, 1), w, axis=1) - hscale
nd.repeat(nd.arange(h, ctx=src.ctx).astype('float32').reshape(h, 1), w, axis=1) - hscale
).expand_dims(axis=0)
w_matrix = (
nd.repeat(nd.arange(w, ctx=src.context).astype('float32').reshape(1, w), h, axis=0) - wscale
nd.repeat(nd.arange(w, ctx=src.ctx).astype('float32').reshape(1, w), h, axis=0) - wscale
).expand_dims(axis=0)
# perform rotation on the grid
c_alpha = nd.cos(rotation_rad)
Expand All @@ -689,7 +689,7 @@ def imrotate(src, rotation_degrees, zoom_in=False, zoom_out=False):
w_matrix_rot = w_matrix_rot / wscale
h_matrix_rot = h_matrix_rot / hscale

h, w = nd.array([h], ctx=src.context), nd.array([w], ctx=src.context)
h, w = nd.array([h], ctx=src.ctx), nd.array([w], ctx=src.ctx)
# compute the scale factor in case `zoom_in` or `zoom_out` are True
if zoom_in or zoom_out:
rho_corner = nd.sqrt(h * h + w * w)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=
atol = get_atol(atol)
use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
if not use_np_allclose:
if not (hasattr(a, 'context') and hasattr(b, 'context') and a.context == b.context and a.dtype == b.dtype):
if not (hasattr(a, 'ctx') and hasattr(b, 'ctx') and a.ctx == b.ctx and a.dtype == b.dtype):
use_np_allclose = True
if isinstance(a, mx.nd.NDArray):
a = a.asnumpy()
Expand Down
2 changes: 1 addition & 1 deletion src/io/iter_prefetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class PrefetcherIter : public IIterator<DataBatch> {
// copy data over
for (size_t i = 0; i < batch.data.size(); ++i) {
if ((*dptr)->data.at(i).shape() != batch.data[i].shape_) {
// perf warning, dynamic buffer might be slow
// TODO(zhreshold): memory pool for dynamic shaped data
(*dptr)->data.at(i).ReshapeAndAlloc(batch.data[i].shape_);
}
CHECK_EQ((*dptr)->data.at(i).shape(), batch.data[i].shape_);
Expand Down
8 changes: 3 additions & 5 deletions src/io/iter_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ class SequentialSampler : public IIterator<DataInst> {
param_.InitAllowUnknown(kwargs);
indices_.resize(param_.length);
std::iota(std::begin(indices_), std::end(indices_), 0); // fill like arange
out_.data.resize(2); // label required by DataBatch, we can use fake label here
out_.data[1] = TBlob(indices_.data(), TShape({1, }), cpu::kDevMask, 0);
out_.data.resize(1);
}

virtual void BeforeFirst(void) {
Expand Down Expand Up @@ -129,8 +128,7 @@ class RandomSampler : public IIterator<DataInst> {
mshadow::Random<cpu> *ctx_rng = ResourceManager::Get()->Request(
Context::CPU(), ResourceRequest::kRandom).get_random<cpu, real_t>(nullptr);
rng_.reset(new common::RANDOM_ENGINE(ctx_rng->GetSeed() + param_.seed));
out_.data.resize(2); // label required by DataBatch, we can use fake label here
out_.data[1] = TBlob(indices_.data(), TShape({1, }), cpu::kDevMask, 0);
out_.data.resize(1);
BeforeFirst();
}

Expand Down Expand Up @@ -164,7 +162,7 @@ class RandomSampler : public IIterator<DataInst> {
/*! \brief data for next value */
DataInst out_;
/*! \brief random generator engine */
std::unique_ptr<common::RANDOM_ENGINE> rng_;
std::unique_ptr<std::mt19937> rng_;
/*! \brief arguments */
RandomSamplerParam param_;
}; // class RandomSampler
Expand Down

0 comments on commit 86b7c3f

Please sign in to comment.