Skip to content

Commit

Permalink
Reorder file for easier read
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
  • Loading branch information
stiepan committed Mar 9, 2023
1 parent 267d261 commit f5ff9a2
Showing 1 changed file with 75 additions and 75 deletions.
150 changes: 75 additions & 75 deletions dali/python/nvidia/dali/auto_aug/auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,81 +32,6 @@
"Please install numpy to use the examples.")


def get_image_net_policy(use_shape: bool = False, max_translate_abs: int = None,
max_translate_rel: float = None) -> Policy:
"""
Creates augmentation policy tuned for the ImageNet as described in AutoAugment
(https://arxiv.org/abs/1805.09501).
The returned policy can be run with `apply_auto_augment`.
Parameter
---------
use_shape : bool
If true, the translation offset is computed as a percentage of the image. Useful if the
images processed with the auto augment have different shapes. If false, the offsets range
is bounded by a constant (`max_translate_abs`).
max_translate_abs: int or (int, int), optional
Only valid with use_shape=False, specifies the maximal shift (in pixels) in the translation
augmentations. If tuple is specified, the first component limits height, the second the
width.
max_translate_rel: float or (float, float), optional
Only valid with use_shape=True, specifies the maximal shift as a fraction of image shape
in the translation augmentations. If tuple is specified, the first component limits
height, the second the width.
"""
translate_y = _get_translate_y(use_shape, max_translate_abs, max_translate_rel)
shear_x = a.shear_x.augmentation((0, 0.3), True)
shear_y = a.shear_y.augmentation((0, 0.3), True)
rotate = a.rotate.augmentation((0, 30), True)
color = a.color.augmentation((0.1, 1.9), False, None)
posterize = a.posterize.augmentation((0, 4), False, a.poster_mask_uint8)
solarize = a.solarize.augmentation((0, 256), False)
solarize_add = a.solarize_add.augmentation((0, 110), False)
invert = a.invert
equalize = a.equalize
auto_contrast = a.auto_contrast
return Policy(
name="ImageNetPolicy", num_magnitude_bins=11, sub_policies=[
[(equalize, 0.8, 1), (shear_y, 0.8, 4)],
[(color, 0.4, 9), (equalize, 0.6, 3)],
[(color, 0.4, 1), (rotate, 0.6, 8)],
[(solarize, 0.8, 3), (equalize, 0.4, 7)],
[(solarize, 0.4, 2), (solarize, 0.6, 2)],
[(color, 0.2, 0), (equalize, 0.8, 8)],
[(equalize, 0.4, 8), (solarize_add, 0.8, 3)],
[(shear_x, 0.2, 9), (rotate, 0.6, 8)],
[(color, 0.6, 1), (equalize, 1.0, 2)],
[(invert, 0.4, 9), (rotate, 0.6, 0)],
[(equalize, 1.0, 9), (shear_y, 0.6, 3)],
[(color, 0.4, 7), (equalize, 0.6, 0)],
[(posterize, 0.4, 6), (auto_contrast, 0.4, 7)],
[(solarize, 0.6, 8), (color, 0.6, 9)],
[(solarize, 0.2, 4), (rotate, 0.8, 9)],
[(rotate, 1.0, 7), (translate_y, 0.8, 9)],
[(shear_x, 0.0, 0), (solarize, 0.8, 4)],
[(shear_y, 0.8, 0), (color, 0.6, 4)],
[(color, 1.0, 0), (rotate, 0.6, 2)],
[(equalize, 0.8, 4)],
[(equalize, 1.0, 4), (auto_contrast, 0.6, 2)],
[(shear_y, 0.4, 7), (solarize_add, 0.6, 7)],
[(posterize, 0.8, 2), (solarize, 0.6, 10)],
[(solarize, 0.6, 8), (equalize, 0.6, 1)],
[(color, 0.8, 6), (rotate, 0.4, 5)],
])


def _get_translate_y(use_shape: bool = False, max_translate_abs: int = None,
max_translate_rel: float = None):
_, max_translate_width = _parse_validate_offset(use_shape, max_translate_abs=max_translate_abs,
max_translate_rel=max_translate_rel,
default_translate_abs=250,
default_translate_rel=1.)
if use_shape:
return a.translate_y.augmentation((0, max_translate_width), True)
else:
return a.translate_y_no_shape.augmentation((0, max_translate_width), True)


def auto_augment_image_net(sample: _DataNode, shape: Optional[_DataNode] = None,
fill_value: Optional[int] = 128,
interp_type: Optional[types.DALIInterpType] = None,
Expand Down Expand Up @@ -235,3 +160,78 @@ def sub_policy_to_augmentation_map(policy: Policy) -> Tuple[_DataNode, List[_Aug
for stage_idx, (augment, p, mag) in enumerate(sub_policy):
augments_by_id[sub_policy_id, stage_idx] = augment_to_id[augment]
return types.Constant(augments_by_id), augmentations


def get_image_net_policy(use_shape: bool = False, max_translate_abs: int = None,
max_translate_rel: float = None) -> Policy:
"""
Creates augmentation policy tuned for the ImageNet as described in AutoAugment
(https://arxiv.org/abs/1805.09501).
The returned policy can be run with `apply_auto_augment`.
Parameter
---------
use_shape : bool
If true, the translation offset is computed as a percentage of the image. Useful if the
images processed with the auto augment have different shapes. If false, the offsets range
is bounded by a constant (`max_translate_abs`).
max_translate_abs: int or (int, int), optional
Only valid with use_shape=False, specifies the maximal shift (in pixels) in the translation
augmentations. If tuple is specified, the first component limits height, the second the
width.
max_translate_rel: float or (float, float), optional
Only valid with use_shape=True, specifies the maximal shift as a fraction of image shape
in the translation augmentations. If tuple is specified, the first component limits
height, the second the width.
"""
translate_y = _get_translate_y(use_shape, max_translate_abs, max_translate_rel)
shear_x = a.shear_x.augmentation((0, 0.3), True)
shear_y = a.shear_y.augmentation((0, 0.3), True)
rotate = a.rotate.augmentation((0, 30), True)
color = a.color.augmentation((0.1, 1.9), False, None)
posterize = a.posterize.augmentation((0, 4), False, a.poster_mask_uint8)
solarize = a.solarize.augmentation((0, 256), False)
solarize_add = a.solarize_add.augmentation((0, 110), False)
invert = a.invert
equalize = a.equalize
auto_contrast = a.auto_contrast
return Policy(
name="ImageNetPolicy", num_magnitude_bins=11, sub_policies=[
[(equalize, 0.8, 1), (shear_y, 0.8, 4)],
[(color, 0.4, 9), (equalize, 0.6, 3)],
[(color, 0.4, 1), (rotate, 0.6, 8)],
[(solarize, 0.8, 3), (equalize, 0.4, 7)],
[(solarize, 0.4, 2), (solarize, 0.6, 2)],
[(color, 0.2, 0), (equalize, 0.8, 8)],
[(equalize, 0.4, 8), (solarize_add, 0.8, 3)],
[(shear_x, 0.2, 9), (rotate, 0.6, 8)],
[(color, 0.6, 1), (equalize, 1.0, 2)],
[(invert, 0.4, 9), (rotate, 0.6, 0)],
[(equalize, 1.0, 9), (shear_y, 0.6, 3)],
[(color, 0.4, 7), (equalize, 0.6, 0)],
[(posterize, 0.4, 6), (auto_contrast, 0.4, 7)],
[(solarize, 0.6, 8), (color, 0.6, 9)],
[(solarize, 0.2, 4), (rotate, 0.8, 9)],
[(rotate, 1.0, 7), (translate_y, 0.8, 9)],
[(shear_x, 0.0, 0), (solarize, 0.8, 4)],
[(shear_y, 0.8, 0), (color, 0.6, 4)],
[(color, 1.0, 0), (rotate, 0.6, 2)],
[(equalize, 0.8, 4)],
[(equalize, 1.0, 4), (auto_contrast, 0.6, 2)],
[(shear_y, 0.4, 7), (solarize_add, 0.6, 7)],
[(posterize, 0.8, 2), (solarize, 0.6, 10)],
[(solarize, 0.6, 8), (equalize, 0.6, 1)],
[(color, 0.8, 6), (rotate, 0.4, 5)],
])


def _get_translate_y(use_shape: bool = False, max_translate_abs: int = None,
max_translate_rel: float = None):
_, max_translate_width = _parse_validate_offset(use_shape, max_translate_abs=max_translate_abs,
max_translate_rel=max_translate_rel,
default_translate_abs=250,
default_translate_rel=1.)
if use_shape:
return a.translate_y.augmentation((0, max_translate_width), True)
else:
return a.translate_y_no_shape.augmentation((0, max_translate_width), True)

0 comments on commit f5ff9a2

Please sign in to comment.