From 0f79fca7f34a4d6c98c01afdb6a8b687276a006f Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Tue, 3 Mar 2020 18:35:28 -0800 Subject: [PATCH] adding augmentations into vision.transforms --- python/mxnet/gluon/data/vision/transforms.py | 50 +++++++++++++++++++ .../python/unittest/test_gluon_data_vision.py | 13 +++++ 2 files changed, 63 insertions(+) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 8a2ca2cf1068..d28575c6522c 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -696,3 +696,53 @@ def forward(self, x): return x x = self.transforms(x) return x + + +class HybridRandomApply(HybridSequential): + """Apply a list of transformations randomly given probability + + Parameters + ---------- + transforms + List of transformations which must be HybridBlocks. + p : float + Probability of applying the transformations. + + + Inputs: + - **data**: input tensor. + + Outputs: + - **out**: transformed image. + """ + + def __init__(self, transforms, p=0.5): + super(HybridRandomApply, self).__init__() + assert isinstance(transforms, HybridBlock) + self.transforms = transforms + self.p = p + + def hybrid_forward(self, F, x): + cond = self.p < F.random.uniform(low=0, high=1, shape=1) + return F.contrib.cond(cond, x, self.transforms(x)) + + +class RandomGray(HybridBlock): + """Randomly convert to gray image. + + Parameters + ---------- + p : float + Probability to convert to grayscale + """ + def __init__(self, p=0.5): + super(RandomGray, self).__init__() + self.mat = self.params.get_constant('gray_mat', value=[[0.21, 0.21, 0.21], + [0.72, 0.72, 0.72], + [0.07, 0.07, 0.07]]) + self.mat.initialize() + self.p = p + + def hybrid_forward(self, F, x, mat): + cond = self.p < F.random.uniform(low=0, high=1, shape=1) + return F.contrib.cond(cond, lambda: x, lambda: F.dot(F.cast(x, dtype='float32'), mat)) \ No newline at end of file diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 747cd83e3038..a9891e84214d 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -278,6 +278,19 @@ def test_random_transforms(): num_apply += 1 assert_almost_equal(num_apply/float(iteration), 0.5, 0.1) +@with_seed() +def test_random_gray(): + from mxnet.gluon.data.vision import transforms + + transform = transforms.RandomGray(0.5) + img = mx.nd.ones((4, 4, 3), dtype='uint8') + iteration = 1000 + num_apply = 0 + for _ in range(iteration): + out = transform(img) + if out.dtype != np.uint8: + num_apply += 1 + assert_almost_equal(num_apply/float(iteration), 0.5, 0.1) if __name__ == '__main__': import nose