Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding augmentations into vision.transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Mar 4, 2020
1 parent d66ef96 commit 0f79fca
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
50 changes: 50 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,53 @@ def forward(self, x):
return x
x = self.transforms(x)
return x


class HybridRandomApply(HybridSequential):
"""Apply a list of transformations randomly given probability
Parameters
----------
transforms
List of transformations which must be HybridBlocks.
p : float
Probability of applying the transformations.
Inputs:
- **data**: input tensor.
Outputs:
- **out**: transformed image.
"""

def __init__(self, transforms, p=0.5):
super(HybridRandomApply, self).__init__()
assert isinstance(transforms, HybridBlock)
self.transforms = transforms
self.p = p

def hybrid_forward(self, F, x):
cond = self.p < F.random.uniform(low=0, high=1, shape=1)
return F.contrib.cond(cond, x, self.transforms(x))


class RandomGray(HybridBlock):
"""Randomly convert to gray image.
Parameters
----------
p : float
Probability to convert to grayscale
"""
def __init__(self, p=0.5):
super(RandomGray, self).__init__()
self.mat = self.params.get_constant('gray_mat', value=[[0.21, 0.21, 0.21],
[0.72, 0.72, 0.72],
[0.07, 0.07, 0.07]])
self.mat.initialize()
self.p = p

def hybrid_forward(self, F, x, mat):
cond = self.p < F.random.uniform(low=0, high=1, shape=1)
return F.contrib.cond(cond, lambda: x, lambda: F.dot(F.cast(x, dtype='float32'), mat))
13 changes: 13 additions & 0 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,19 @@ def test_random_transforms():
num_apply += 1
assert_almost_equal(num_apply/float(iteration), 0.5, 0.1)

@with_seed()
def test_random_gray():
from mxnet.gluon.data.vision import transforms

transform = transforms.RandomGray(0.5)
img = mx.nd.ones((4, 4, 3), dtype='uint8')
iteration = 1000
num_apply = 0
for _ in range(iteration):
out = transform(img)
if out.dtype != np.uint8:
num_apply += 1
assert_almost_equal(num_apply/float(iteration), 0.5, 0.1)

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 0f79fca

Please sign in to comment.