diff --git a/dali/python/nvidia/dali/auto_aug/augmentations.py b/dali/python/nvidia/dali/auto_aug/augmentations.py index 3088884c9f..bfc2755858 100644 --- a/dali/python/nvidia/dali/auto_aug/augmentations.py +++ b/dali/python/nvidia/dali/auto_aug/augmentations.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +try: + import numpy as np +except ImportError: + raise RuntimeError( + "Could not import numpy. DALI's automatic augmentation examples depend on numpy. " + "Please install numpy to use the examples.") from nvidia.dali import fn from nvidia.dali import types @@ -23,7 +28,7 @@ The `@augmentation` decorator handles computation of the decorated transformations's parameter. When called, the decorated augmentation expects: -* a single positional argument: batch o samples +* a single positional argument: batch of samples * `magnitude_bin` and `num_magnitude_bins` instead of the parameter. The parameter is computed as if by calling `as_param(magnitudes[magnitude_bin] * ((-1) ** random_sign))`, where @@ -60,7 +65,7 @@ def shear_y(sample, shear, fill_value=128, interp_type=None): @augmentation(mag_range=(0., 1.), randomly_negate=True, as_param=warp_x_param) def translate_x(sample, rel_offset, shape, fill_value=128, interp_type=None): - offset = rel_offset * shape[-2] + offset = rel_offset * shape[1] mt = fn.transforms.translation(offset=offset) return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type, inverse_map=False) @@ -75,7 +80,7 @@ def translate_x_no_shape(sample, offset, fill_value=128, interp_type=None): @augmentation(mag_range=(0., 1.), randomly_negate=True, as_param=warp_y_param) def translate_y(sample, rel_offset, shape, fill_value=128, interp_type=None): - offset = rel_offset * shape[-3] + offset = rel_offset * shape[0] mt = fn.transforms.translation(offset=offset) return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type, inverse_map=False) diff --git a/dali/python/nvidia/dali/auto_aug/auto_augment.py b/dali/python/nvidia/dali/auto_aug/auto_augment.py new file mode 100644 index 0000000000..7ca11647d7 --- /dev/null +++ b/dali/python/nvidia/dali/auto_aug/auto_augment.py @@ -0,0 +1,237 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +from nvidia.dali import fn +from nvidia.dali import types +from nvidia.dali.auto_aug import augmentations as a +from nvidia.dali.auto_aug.core import _Augmentation, Policy, signed_bin +from nvidia.dali.auto_aug.core._args import forbid_unused_kwargs as _forbid_unused_kwargs +from nvidia.dali.auto_aug.core._utils import \ + parse_validate_offset as _parse_validate_offset, \ + pretty_select as _pretty_select +from nvidia.dali.data_node import DataNode as _DataNode + +try: + import numpy as np +except ImportError: + raise RuntimeError( + "Could not import numpy. DALI's automatic augmentation examples depend on numpy. " + "Please install numpy to use the examples.") + + +def auto_augment_image_net(sample: _DataNode, shape: Optional[_DataNode] = None, + fill_value: Optional[int] = 128, + interp_type: Optional[types.DALIInterpType] = None, + max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None, seed: Optional[int] = None): + """ + Applies `auto_augment_image_net_policy` in AutoAugment (https://arxiv.org/abs/1805.09501) + fashion to the provided batch of samples. + + Parameter + --------- + sample : DataNode + A batch of samples to be processed. The samples should be images of `HWC` layout, + `uint8` type and reside on GPU. + shapes: DataNode, optional + A batch of shapes of the `sample`. If specified, the magnitude of `translation` + operations depends on the image shape and spans from 0 to `max_translate_rel * shape`. + Otherwise, the magnitude range is `[0, max_translate_abs]` for any sample. + fill_value: int, optional + A value to be used as a padding for images transformed with warp_affine ops + (translation, shear and rotate). If `None` is specified, the images are padded + with the border value repeated (clamped). + interp_type: types.DALIInterpType, optional + Interpolation method used by the warp_affine ops (translation, shear and rotate). + Supported values are `types.INTERP_LINEAR` (default) and `types.INTERP_NN`. + seed: int, optional + Seed to be used to randomly sample operations (and to negate magnitudes). + + Returns + ------- + DataNode + A batch of transformed samples. + """ + aug_kwargs = {"fill_value": fill_value, "interp_type": interp_type} + use_shape = shape is not None + if use_shape: + aug_kwargs["shape"] = shape + image_net_policy = get_image_net_policy(use_shape=use_shape, + max_translate_abs=max_translate_abs, + max_translate_rel=max_translate_rel) + return apply_auto_augment(image_net_policy, sample, seed, **aug_kwargs) + + +def apply_auto_augment(policy: Policy, sample: _DataNode, seed: Optional[int] = None, + **kwargs) -> _DataNode: + """ + Applies AutoAugment (https://arxiv.org/abs/1805.09501) augmentation scheme to the + provided batch of samples. + + Parameter + --------- + policy: Policy + Set of sequences of augmentations to be applied in AutoAugment fashion. + sample : DataNode + A batch of samples to be processed. + seed: int, optional + Seed to be used to randomly sample operations (and to negate magnitudes). + kwargs: + A dictionary of extra parameters to be passed when calling augmentations. + The signature of each augmentation is checked for any extra arguments and if + the name of the argument matches one from the `kwargs`, the value is + passed as an argument. For example, some augmentations from the default + random augment suite accept `shapes`, `fill_value` and `interp_type`. + + Returns + ------- + DataNode + A batch of transformed samples. + """ + if len(policy.sub_policies) == 0: + raise Exception(f"Cannot run empty policy. Got {policy} in `apply_auto_augment` call.") + max_policy_len = max(len(sub_policy) for sub_policy in policy.sub_policies) + should_run = fn.random.uniform(range=[0, 1], shape=(max_policy_len, ), dtype=types.FLOAT) + sub_policy_id = fn.random.uniform(values=list(range(len(policy.sub_policies))), seed=seed, + dtype=types.INT32) + run_probabilities = _sub_policy_to_probability_map(policy)[sub_policy_id] + magnitude_bins = _sub_policy_to_magnitude_bin_map(policy)[sub_policy_id] + aug_ids, augmentations = _sub_policy_to_augmentation_map(policy) + aug_ids = aug_ids[sub_policy_id] + use_signed_magnitudes = any(aug.randomly_negate for aug in policy.augmentations.values()) + _forbid_unused_kwargs(augmentations, kwargs, 'apply_auto_augment') + for stage_id in range(max_policy_len): + magnitude_bin = magnitude_bins[stage_id] + if use_signed_magnitudes: + magnitude_bin = signed_bin(magnitude_bin) + if should_run[stage_id] < run_probabilities[stage_id]: + op_kwargs = dict(sample=sample, magnitude_bin=magnitude_bin, + num_magnitude_bins=policy.num_magnitude_bins, **kwargs) + sample = _pretty_select(augmentations, aug_ids[stage_id], op_kwargs, + auto_aug_name='apply_auto_augment', + ref_suite_name='get_image_net_policy') + return sample + + +def get_image_net_policy(use_shape: bool = False, max_translate_abs: int = None, + max_translate_rel: float = None) -> Policy: + """ + Creates augmentation policy tuned for the ImageNet as described in AutoAugment + (https://arxiv.org/abs/1805.09501). + The returned policy can be run with `apply_auto_augment`. + + Parameter + --------- + use_shape : bool + If true, the translation offset is computed as a percentage of the image. Useful if the + images processed with the auto augment have different shapes. If false, the offsets range + is bounded by a constant (`max_translate_abs`). + max_translate_abs: int or (int, int), optional + Only valid with use_shape=False, specifies the maximal shift (in pixels) in the translation + augmentations. If tuple is specified, the first component limits height, the second the + width. + max_translate_rel: float or (float, float), optional + Only valid with use_shape=True, specifies the maximal shift as a fraction of image shape + in the translation augmentations. If tuple is specified, the first component limits + height, the second the width. + """ + translate_y = _get_translate_y(use_shape, max_translate_abs, max_translate_rel) + shear_x = a.shear_x.augmentation((0, 0.3), True) + shear_y = a.shear_y.augmentation((0, 0.3), True) + rotate = a.rotate.augmentation((0, 30), True) + color = a.color.augmentation((0.1, 1.9), False, None) + posterize = a.posterize.augmentation((0, 4), False, a.poster_mask_uint8) + solarize = a.solarize.augmentation((0, 256), False) + solarize_add = a.solarize_add.augmentation((0, 110), False) + invert = a.invert + equalize = a.equalize + auto_contrast = a.auto_contrast + return Policy( + name="ImageNetPolicy", num_magnitude_bins=11, sub_policies=[ + [(equalize, 0.8, 1), (shear_y, 0.8, 4)], + [(color, 0.4, 9), (equalize, 0.6, 3)], + [(color, 0.4, 1), (rotate, 0.6, 8)], + [(solarize, 0.8, 3), (equalize, 0.4, 7)], + [(solarize, 0.4, 2), (solarize, 0.6, 2)], + [(color, 0.2, 0), (equalize, 0.8, 8)], + [(equalize, 0.4, 8), (solarize_add, 0.8, 3)], + [(shear_x, 0.2, 9), (rotate, 0.6, 8)], + [(color, 0.6, 1), (equalize, 1.0, 2)], + [(invert, 0.4, 9), (rotate, 0.6, 0)], + [(equalize, 1.0, 9), (shear_y, 0.6, 3)], + [(color, 0.4, 7), (equalize, 0.6, 0)], + [(posterize, 0.4, 6), (auto_contrast, 0.4, 7)], + [(solarize, 0.6, 8), (color, 0.6, 9)], + [(solarize, 0.2, 4), (rotate, 0.8, 9)], + [(rotate, 1.0, 7), (translate_y, 0.8, 9)], + [(shear_x, 0.0, 0), (solarize, 0.8, 4)], + [(shear_y, 0.8, 0), (color, 0.6, 4)], + [(color, 1.0, 0), (rotate, 0.6, 2)], + [(equalize, 0.8, 4)], + [(equalize, 1.0, 4), (auto_contrast, 0.6, 2)], + [(shear_y, 0.4, 7), (solarize_add, 0.6, 7)], + [(posterize, 0.8, 2), (solarize, 0.6, 10)], + [(solarize, 0.6, 8), (equalize, 0.6, 1)], + [(color, 0.8, 6), (rotate, 0.4, 5)], + ]) + + +def _get_translate_y(use_shape: bool = False, max_translate_abs: int = None, + max_translate_rel: float = None): + max_translate_height, _ = _parse_validate_offset(use_shape, max_translate_abs=max_translate_abs, + max_translate_rel=max_translate_rel, + default_translate_abs=250, + default_translate_rel=1.) + if use_shape: + return a.translate_y.augmentation((0, max_translate_height), True) + else: + return a.translate_y_no_shape.augmentation((0, max_translate_height), True) + + +def _sub_policy_to_probability_map(policy: Policy) -> _DataNode: + sub_policies = policy.sub_policies + max_policy_len = max(len(sub_policy) for sub_policy in sub_policies) + prob = np.array([[0. for _ in range(max_policy_len)] for _ in range(len(sub_policies))], + dtype=np.float32) + for sub_policy_id, sub_policy in enumerate(sub_policies): + for stage_idx, (aug_name, p, mag) in enumerate(sub_policy): + prob[sub_policy_id, stage_idx] = p + return types.Constant(prob) + + +def _sub_policy_to_magnitude_bin_map(policy: Policy) -> _DataNode: + sub_policies = policy.sub_policies + max_policy_len = max(len(sub_policy) for sub_policy in sub_policies) + magnitude_bin = np.array([[0 for _ in range(max_policy_len)] for _ in range(len(sub_policies))], + dtype=np.int32) + for sub_policy_id, sub_policy in enumerate(sub_policies): + for stage_idx, (aug_name, p, mag) in enumerate(sub_policy): + magnitude_bin[sub_policy_id, stage_idx] = mag + return types.Constant(magnitude_bin) + + +def _sub_policy_to_augmentation_map(policy: Policy) -> Tuple[_DataNode, List[_Augmentation]]: + sub_policies = policy.sub_policies + max_policy_len = max(len(sub_policy) for sub_policy in sub_policies) + augmentations = list(policy.augmentations.values()) + [a.identity] + identity_id = len(augmentations) - 1 + augment_to_id = {augmentation: i for i, augmentation in enumerate(augmentations)} + augments_by_id = np.array([[identity_id for _ in range(max_policy_len)] + for _ in range(len(sub_policies))], dtype=np.int32) + for sub_policy_id, sub_policy in enumerate(sub_policies): + for stage_idx, (augment, p, mag) in enumerate(sub_policy): + augments_by_id[sub_policy_id, stage_idx] = augment_to_id[augment] + return types.Constant(augments_by_id), augmentations diff --git a/dali/python/nvidia/dali/auto_aug/core/__init__.py b/dali/python/nvidia/dali/auto_aug/core/__init__.py index 3c803fd0f1..bce12701d9 100644 --- a/dali/python/nvidia/dali/auto_aug/core/__init__.py +++ b/dali/python/nvidia/dali/auto_aug/core/__init__.py @@ -15,5 +15,6 @@ from nvidia.dali.auto_aug.core._augmentation import signed_bin, Augmentation as _Augmentation from nvidia.dali.auto_aug.core.decorator import augmentation from nvidia.dali.auto_aug.core._select import select +from nvidia.dali.auto_aug.core.policy import Policy -__all__ = ("signed_bin", "augmentation", "select", "_Augmentation") +__all__ = ("signed_bin", "augmentation", "select", "Policy", "_Augmentation") diff --git a/dali/python/nvidia/dali/auto_aug/core/_augmentation.py b/dali/python/nvidia/dali/auto_aug/core/_augmentation.py index 8cfcff6046..c1a384ef10 100644 --- a/dali/python/nvidia/dali/auto_aug/core/_augmentation.py +++ b/dali/python/nvidia/dali/auto_aug/core/_augmentation.py @@ -76,8 +76,8 @@ def signed_magnitude_idx(self): return self._signed_magnitude_idx -def signed_bin(magnitude_bin: Optional[Union[int, _DataNode]], - random_sign: Optional[_DataNode] = None, seed=None) -> _SignedMagnitudeBin: +def signed_bin(magnitude_bin: Union[int, _DataNode], random_sign: Optional[_DataNode] = None, + seed=None) -> _SignedMagnitudeBin: """ Combines the `magnitude_bin` with information about the sign of the magnitude. The Augmentation wrapper can generate and handle the random sign on its own. Yet, @@ -90,8 +90,8 @@ def signed_bin(magnitude_bin: Optional[Union[int, _DataNode]], magnitude_bin: int or DataNode The magnitude bin from range `[0, num_magnitude_bins - 1]`. Can be plain int or a batch (_DataNode) of ints. - random_sign : DataNode - A batch of {0, 1} integers. For augmentations declared with `random_negate=True`, + random_sign : DataNode, optional + A batch of {0, 1} integers. For augmentations declared with `randomly_negate=True`, it determines if the magnitude is negated (for 1) or not (for 0). """ return _SignedMagnitudeBin(magnitude_bin, random_sign, seed) @@ -295,9 +295,9 @@ def _get_param(self, magnitude_bin, num_magnitude_bins): f"but unsigned `magnitude_bin` was passed to the augmentation call. " f"The augmentation will randomly negate the magnitudes manually. " f"However, for better performance, if you conditionally split batch " - f"between multiple augmentations, please call " - f"`signed_magnitude_bin = signed_bin(magnitude_bin)` and pass the " - f"signed bins instead.", Warning) + f"between multiple augmentations, it is better to call " + f"`signed_magnitude_bin = signed_bin(magnitude_bin)` before the split " + f"and pass the signed bins instead.", Warning) if self.randomly_negate: assert isinstance(magnitude_bin, _SignedMagnitudeBin) # by the two checks above if isinstance(magnitude_bin.bin, int): diff --git a/dali/python/nvidia/dali/auto_aug/core/_utils.py b/dali/python/nvidia/dali/auto_aug/core/_utils.py new file mode 100644 index 0000000000..9dbd409339 --- /dev/null +++ b/dali/python/nvidia/dali/auto_aug/core/_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from nvidia.dali.data_node import DataNode as _DataNode + +from nvidia.dali.auto_aug.core._select import select +from nvidia.dali.auto_aug.core._args import MissingArgException +from nvidia.dali.auto_aug.core._augmentation import Augmentation +import nvidia.dali.auto_aug.augmentations as a + + +def max_translate_hw(max_translate): + if isinstance(max_translate, (tuple, list)): + height, width = max_translate + return height, width + return max_translate, max_translate + + +def parse_validate_offset(use_shape, max_translate_abs=None, max_translate_rel=None, + default_translate_abs=250, default_translate_rel=1.): + # if one passes DataNode (with shapes for instance), the error message would be very vague + if not isinstance(use_shape, bool): + raise Exception(f"The `use_shape` is a flag that should be set to either True or False, " + f"got {use_shape}.") + if use_shape: + if max_translate_abs is not None: + raise Exception("The argument `max_translate_abs` cannot be used with image shapes. " + "You may use `max_translate_rel` instead.") + if max_translate_rel is None: + max_translate_rel = default_translate_rel + return max_translate_hw(max_translate_rel) + else: + if max_translate_rel is not None: + raise Exception("The argument `max_translate_rel` cannot be used without image shapes. " + "You may use `max_translate_abs` instead.") + if max_translate_abs is None: + max_translate_abs = default_translate_abs + return max_translate_hw(max_translate_abs) + + +def pretty_select(augmentations: List[Augmentation], aug_ids: _DataNode, op_kwargs, + auto_aug_name: str, ref_suite_name: str): + try: + return select(augmentations, aug_ids, **op_kwargs) + except MissingArgException as e: + if e.missing_args != ['shape'] or e.augmentation.op not in [ + a.translate_x.op, a.translate_y.op + ]: + raise + else: + raise Exception( + f"The augmentation `{e.augmentation.name}` requires `shape` argument that " + f"describes image shape (in HWC layout). Please provide it as `shape` argument " + f"to `{auto_aug_name}` call. You can get the image shape from encoded " + f"images with `fn.peek_image_shape`. Alternatively, you can use " + f"`translate_x_no_shape`/`translate_y_no_shape` that does not rely on image " + f"shape, but uses offset from fixed range: for reference see `{ref_suite_name}` " + f"and its `use_shape` argument. ") diff --git a/dali/python/nvidia/dali/auto_aug/core/policy.py b/dali/python/nvidia/dali/auto_aug/core/policy.py new file mode 100644 index 0000000000..1c9d778f42 --- /dev/null +++ b/dali/python/nvidia/dali/auto_aug/core/policy.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvidia.dali.auto_aug.core._augmentation import Augmentation +from typing import Sequence, Tuple + + +class Policy: + + def __init__(self, name: str, num_magnitude_bins: int, + sub_policies: Sequence[Sequence[Tuple[Augmentation, float, int]]]): + """ + Describes the augmentation policy as introduced in AutoAugment + (https://arxiv.org/abs/1805.09501). + + Parameter + --------- + name : str + A name of the policy, for presentation purposes. + num_magnitude_bins : int + The number of bins that augmentations' magnitude ranges should be divided into. + sub_policies: Sequence[Sequence[Tuple[Augmentation, float, int]]] + A list of sequences of transformations. For each processed sample, one of the + sequences is chosen uniformly at random. Then, the tuples from the sequence + are considered one by one. Each tuple describes what augmentation to apply at + that point, what is the probability of skipping the augmentation at that time + and what magnitude to use with the augmentation. + """ + self.name = name + self.num_magnitude_bins = num_magnitude_bins + if not isinstance(num_magnitude_bins, int) or num_magnitude_bins < 1: + raise Exception( + f"The `num_magnitude_bins` must be a positive integer, got {num_magnitude_bins}.") + if not isinstance(sub_policies, (list, tuple)): + raise Exception(f"The `sub_policies` must be a list or tuple of sub policies, " + f"got {type(sub_policies)}.") + for sub_policy in sub_policies: + if not isinstance(sub_policy, (list, tuple)): + raise Exception(f"Each sub policy must be a list or tuple, got {sub_policy}.") + for op_desc in sub_policy: + if not isinstance(op_desc, (list, tuple)) or len(op_desc) != 3: + raise Exception(f"Each operation in sub policy must be specified as a triple: " + f"(augmentation, probability, magnitude). Got {op_desc}.") + aug, p, mag = op_desc + if not isinstance(aug, Augmentation): + raise Exception( + f"Each augmentation in sub policies must be an instance of " + f"Augmentation. Got {aug}. Did you forget to use `@augmentation` " + f"decorator?") + if not isinstance(p, (float, int)) or not 0 <= p <= 1: + raise Exception( + f"Probability of applying the augmentation must be a number from " + f"`[0, 1]` range. Got {p} for augmentation `{aug.name}`.") + if not isinstance(mag, int) or not 0 <= mag < self.num_magnitude_bins: + raise Exception(f"Magnitude of the augmentation must be an integer from " + f"`[0, {num_magnitude_bins - 1}]` range. " + f"Got {mag} for augmentation `{aug.name}`.") + self.sub_policies = _sub_policy_with_unique_names(sub_policies) + + @property + def augmentations(self): + augments = set(aug for sub_policy in self.sub_policies for aug, p, mag in sub_policy) + augments = sorted(list(augments), key=lambda aug: aug.name) + return {augment.name: augment for augment in augments} + + def __repr__(self): + sub_policies_repr = ",\n\t".join( + repr([(augment.name, p, mag) for augment, p, mag in sub_policy]) + for sub_policy in self.sub_policies) + sub_policies_repr_sep = "" if not sub_policies_repr else "\n\t" + augmentations_repr = ",\n\t".join(f"'{name}': {repr(augment)}" + for name, augment in self.augmentations.items()) + augmentations_repr_sep = "" if not augmentations_repr else "\n\t" + return ( + f"Policy(name={repr(self.name)}, num_magnitude_bins={repr(self.num_magnitude_bins)}, " + f"sub_policies=[{sub_policies_repr_sep}{sub_policies_repr}], " + f"augmentations={{{augmentations_repr_sep}{augmentations_repr}}})") + + +def _sub_policy_with_unique_names( + sub_policies: Sequence[Sequence[Tuple[Augmentation, float, int]]] +) -> Tuple[Tuple[Tuple[Augmentation, float, int]]]: + """ + Check if the augmentations used in the sub-policies have unique names. + If not, rename them by adding enumeration to the names. + The aim is to have non-ambiguous presentation. + """ + all_augments = [aug for sub_policy in sub_policies for aug, p, mag in sub_policy] + augments = set(all_augments) + names = set(aug.name for aug in augments) + if len(names) == len(augments): + return tuple(tuple(sub_policy) for sub_policy in sub_policies) + num_digits = len(str(len(augments) - 1)) + remap_aug = {} + i = 0 + for augment in all_augments: + if augment not in remap_aug: + remap_aug[augment] = augment.augmentation( + name=f"{str(i).zfill(num_digits)}__{augment.name}") + i += 1 + return tuple( + tuple((remap_aug[aug], p, mag) for aug, p, mag in sub_policy) + for sub_policy in sub_policies) diff --git a/dali/test/python/auto_aug/test_auto_augment.py b/dali/test/python/auto_aug/test_auto_augment.py new file mode 100644 index 0000000000..3d274baab1 --- /dev/null +++ b/dali/test/python/auto_aug/test_auto_augment.py @@ -0,0 +1,459 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os + +import numpy as np +from scipy.stats import chisquare +from nose2.tools import params + +from nvidia.dali import fn, types +from nvidia.dali.pipeline import experimental +from nvidia.dali.auto_aug import auto_augment, augmentations as a +from nvidia.dali.auto_aug.core import augmentation, Policy + +from test_utils import get_dali_extra_path +from nose_utils import assert_raises + +data_root = get_dali_extra_path() +images_dir = os.path.join(data_root, 'db', 'single', 'jpeg') + + +def as_param_with_op_id(op_id): + + def as_param(magnitude): + return np.array([op_id, magnitude], dtype=np.int32) + + return as_param + + +@experimental.pipeline_def(enable_conditionals=True, num_threads=4, device_id=0, seed=44) +def concat_aug_pipeline(dev, policy): + sample = types.Constant(np.array([], dtype=np.int32), device=dev) + if dev == "gpu": + sample = sample.gpu() + sample = auto_augment.apply_auto_augment(policy, sample) + return fn.reshape(sample, shape=(-1, 2)) + + +def collect_sub_policy_outputs(sub_policies, num_magnitude_bins): + sub_policy_outputs = [] + for sub_policy in sub_policies: + out = [] + for aug, _, mag_bin in sub_policy: + magnitudes = aug._get_magnitudes(num_magnitude_bins) + param = aug._map_mag_to_param(magnitudes[mag_bin]) + out.append(param) + sub_policy_outputs.append(out) + return sub_policy_outputs + + +@params(*tuple(enumerate(itertools.product((True, False), (True, False), (None, 0), + (True, False))))) +def test_run_auto_aug(i, args): + uniformly_resized, use_shape, fill_value, specify_translation_bounds = args + batch_sizes = [1, 8, 7, 64, 13, 64, 128] + batch_size = batch_sizes[i % len(batch_sizes)] + + @experimental.pipeline_def(enable_conditionals=True, batch_size=batch_size, num_threads=4, + device_id=0, seed=43) + def pipeline(): + encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir) + image = fn.decoders.image(encoded_image, device="mixed") + if uniformly_resized: + image = fn.resize(image, size=(244, 244)) + extra = {} if not use_shape else {"shape": fn.peek_image_shape(encoded_image)} + if fill_value is not None: + extra["fill_value"] = fill_value + if specify_translation_bounds: + if use_shape: + extra["max_translate_rel"] = 0.9 + else: + extra["max_translate_abs"] = 400 + image = auto_augment.auto_augment_image_net(image, **extra) + return image + + p = pipeline() + p.build() + for _ in range(3): + p.run() + + +@params(*tuple(itertools.product((True, False), (0, 1), ('height', 'width', 'both')))) +def test_translation(use_shape, offset_fraction, extent): + # make sure the translation helper processes the args properly + # note, it only uses translate_y (as it is in imagenet policy) + shape = [300, 400] + fill_value = 217 + params = {} + if use_shape: + param = offset_fraction + param_name = "max_translate_rel" + else: + param_name = "max_translate_abs" + if extent == 'both': + param = shape[0] * offset_fraction + elif extent == 'height': + param = [shape[0] * offset_fraction, 0] + elif extent == 'width': + param = [0, shape[1] * offset_fraction] + else: + assert False, f"Unrecognized extent={extent}" + params[param_name] = param + translate_y = auto_augment._get_translate_y(use_shape=use_shape, **params) + policy = Policy(f"Policy_{use_shape}_{offset_fraction}", num_magnitude_bins=21, + sub_policies=[[(translate_y, 1, 20)]]) + + @experimental.pipeline_def(enable_conditionals=True, batch_size=3, num_threads=4, device_id=0, + seed=43) + def pipeline(): + encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir) + image = fn.decoders.image(encoded_image, device="mixed") + image = fn.resize(image, size=shape) + if use_shape: + return auto_augment.apply_auto_augment(policy, image, fill_value=fill_value, + shape=shape) + else: + return auto_augment.apply_auto_augment(policy, image, fill_value=fill_value) + + p = pipeline() + p.build() + output, = p.run() + output = [np.array(sample) for sample in output.as_cpu()] + for i, sample in enumerate(output): + sample = np.array(sample) + if offset_fraction == 1 and extent != "width": + assert np.all(sample == fill_value), f"sample_idx: {i}" + else: + background_count = np.sum(sample == fill_value) + assert background_count / sample.size < 0.1, \ + f"sample_idx: {i}, {background_count / sample.size}" + + +@params( + (False, "cpu", 256), + (False, "gpu", 512), + (True, "cpu", 400), + (True, "gpu", 348), +) +def test_sub_policy(randomly_negate, dev, batch_size): + + num_magnitude_bins = 10 + + @augmentation( + mag_range=(0, 9), + as_param=as_param_with_op_id(1), + param_device=dev, + ) + def first(sample, op_id_mag_id): + return fn.cat(sample, op_id_mag_id) + + @augmentation( + mag_range=(10, 19), + as_param=as_param_with_op_id(2), + randomly_negate=randomly_negate, + param_device=dev, + ) + def second(sample, op_id_mag_id): + return fn.cat(sample, op_id_mag_id) + + @augmentation( + mag_range=(20, 29), + as_param=as_param_with_op_id(3), + randomly_negate=randomly_negate, + param_device=dev, + ) + def third(sample, op_id_mag_id): + return fn.cat(sample, op_id_mag_id) + + sub_policies = [ + [(first, 1, 0), (second, 1, 5), (third, 1, 3)], + [(first, 1, 1), (third, 1, 4), (first, 1, 2)], + [(second, 1, 2), (first, 1, 3), (third, 1, 4)], + [(second, 1, 3), (third, 1, 2), (first, 1, 5)], + [(third, 1, 4), (first, 1, 1), (second, 1, 1)], + [(third, 1, 5), (second, 1, 9), (first, 1, 2)], + [(first, 1, 6), (first, 1, 1)], + [(third, 1, 7)], + [(first, 1, 8), (first, 1, 4), (second, 1, 7), (second, 1, 6)], + ] + + policy = Policy("MyPolicy", num_magnitude_bins=num_magnitude_bins, sub_policies=sub_policies) + p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy) + p.build() + + sub_policy_outputs = collect_sub_policy_outputs(sub_policies, num_magnitude_bins) + # magnitudes are chosen so that the magnitude of the first op in + # each sub-policy identifies the sub-policy + assert len({out[0][1] for out in sub_policy_outputs}) == len(sub_policy_outputs) + output_cases = {out[0][1]: np.array(out) for out in sub_policy_outputs} + + sub_policy_negation_cases = [] + for sub_policy in sub_policies: + negated = [] + for aug, _, _ in sub_policy: + if aug.randomly_negate: + negated.append((True, False)) + else: + negated.append((False, )) + sub_policy_negation_cases.append(list(itertools.product(*negated))) + assert len(sub_policy_outputs) == len(sub_policy_negation_cases) + + for _ in range(5): + output, = p.run() + if dev == "gpu": + output = output.as_cpu() + output = [np.array(sample) for sample in output] + for sample in output: + test_sample = sample if not randomly_negate else np.abs(sample) + np.testing.assert_equal(np.abs(test_sample), output_cases[test_sample[0][1]]) + for op_mag in sample: + if op_mag[1] < 0: + # the `second` and `third` augmentation are marked as randomly_negated + assert op_mag[0] in [2, 3], f"{sample}" + if randomly_negate: + # for each sub-policy, count occurrences of any possible sequence + # of magnitude signs + negation_cases = { + out[0][1]: {case: 0 + for case in cases} + for out, cases in zip(sub_policy_outputs, sub_policy_negation_cases) + } + for sample in output: + mag_signs = tuple(op_mag[1] < 0 for op_mag in sample) + negation_cases[np.abs(sample[0][1])][mag_signs] += 1 + counts, expected_counts = [], [] + for sub_policy_cases in negation_cases.values(): + expected = batch_size / (len(sub_policies) * len(sub_policy_cases)) + for count in sub_policy_cases.values(): + counts.append(count) + expected_counts.append(expected) + stat = chisquare(counts, expected_counts) + # assert that the magnitudes negation looks independently enough + # (0.05 <=), but also that it is not too ideal (i.e. like all + # cases happening exactly the expected number of times) + assert 0.05 <= stat.pvalue <= 0.95, f"{stat}" + + +@params(("cpu", ), ("gpu", )) +def test_op_skipping(dev): + + num_magnitude_bins = 16 + batch_size = 1024 + + @augmentation( + mag_range=(0, 15), + as_param=as_param_with_op_id(1), + randomly_negate=True, + param_device=dev, + ) + def first(sample, op_id_mag_id): + return fn.cat(sample, op_id_mag_id) + + @augmentation( + mag_range=(0, 15), + as_param=as_param_with_op_id(2), + randomly_negate=True, + param_device=dev, + ) + def second(sample, op_id_mag_id): + return fn.cat(sample, op_id_mag_id) + + @augmentation( + mag_range=(0, 15), + as_param=as_param_with_op_id(3), + param_device=dev, + ) + def third(sample, op_id_mag_id): + return fn.cat(sample, op_id_mag_id) + + sub_policies = [ + [(first, 0.5, 1), (first, 0.25, 2)], + [(second, 0.8, 3), (second, 0.7, 4)], + [(first, 0.9, 5), (second, 0.6, 6)], + [(second, 0.3, 7), (first, 0.25, 8)], + [(third, 1, 9), (third, 0.75, 10)], + [(third, 0.3, 11), (first, 0.22, 12)], + [(second, 0.6, 13), (third, 0, 14)], + ] + + # sub_policy_cases = [[] for _ in range(len(sub_policies))] + expected_counts = {tuple(): 0.} + for (left_aug, left_p, left_mag), (right_aug, right_p, right_mag) in sub_policies: + expected_counts[tuple()] += (1. - left_p) * (1 - right_p) / len(sub_policies) + only_left_p = left_p * (1 - right_p) / len(sub_policies) + only_right_p = (1 - left_p) * right_p / len(sub_policies) + for aug, mag, prob in [(left_aug, left_mag, only_left_p), + (right_aug, right_mag, only_right_p)]: + if not aug.randomly_negate: + expected_counts[(mag, )] = prob + else: + expected_counts[(mag, )] = prob / 2 + expected_counts[(-mag, )] = prob / 2 + sign_cases = [(-1, 1) if aug.randomly_negate else (1, ) for aug in (left_aug, right_aug)] + sign_cases = list(itertools.product(*sign_cases)) + prob = left_p * right_p / len(sub_policies) + for left_sign, right_sign in sign_cases: + mags = (left_sign * left_mag, right_sign * right_mag) + expected_counts[mags] = prob / len(sign_cases) + expected_counts = {mag: prob * batch_size for mag, prob in expected_counts.items() if prob > 0} + + policy = Policy("MyPolicy", num_magnitude_bins=num_magnitude_bins, sub_policies=sub_policies) + p = concat_aug_pipeline(batch_size=batch_size, dev=dev, policy=policy) + p.build() + + for _ in range(5): + output, = p.run() + if dev == "gpu": + output = output.as_cpu() + output = [np.array(sample) for sample in output] + actual_counts = {allowed_case: 0 for allowed_case in expected_counts} + for sample in output: + mags = tuple(int(op_mag[1]) for op_mag in sample) + actual_counts[mags] += 1 + + actual, expected = [], [] + for mags in expected_counts: + actual.append(actual_counts[mags]) + expected.append(expected_counts[mags]) + stat = chisquare(actual, expected) + # assert that the magnitudes negation looks independently enough + # (0.05 <=), but also that it is not too ideal (i.e. like all + # cases happening exactly the expected number of times) + assert 0.05 <= stat.pvalue <= 0.95, f"{stat}" + + +def test_policy_presentation(): + + empty_policy = Policy("EmptyPolicy", num_magnitude_bins=31, sub_policies=[]) + empty_policy_str = str(empty_policy) + assert "sub_policies=[]" in empty_policy_str, empty_policy_str + assert "augmentations={}" in empty_policy_str, empty_policy_str + + def get_first_augment(): + + @augmentation + def clashing_name(sample, _): + return sample + + return clashing_name + + def get_second_augment(): + + @augmentation + def clashing_name(sample, _): + return sample + + return clashing_name + + one = get_first_augment() + another = get_second_augment() + sub_policies = [[(one, 0.1, 5), (another, 0.4, 7)], [(another, 0.2, 1), (one, 0.5, 2)], + [(another, 0.7, 1)]] + policy = Policy(name="DummyPolicy", num_magnitude_bins=11, sub_policies=sub_policies) + assert policy.sub_policies[0][0][0] is policy.sub_policies[1][1][0] + assert policy.sub_policies[0][1][0] is policy.sub_policies[1][0][0] + assert policy.sub_policies[0][1][0] is policy.sub_policies[2][0][0] + assert len(sub_policies) == len(policy.sub_policies) + for sub_pol, pol_sub_pol in zip(sub_policies, policy.sub_policies): + assert len(sub_pol) == len(pol_sub_pol) + for (aug, p, mag), (pol_aug, pol_p, pol_mag) in zip(sub_pol, pol_sub_pol): + assert p == pol_p, f"({aug}, {p}, {mag}), ({pol_aug}, {pol_p}, {pol_mag})" + assert mag == pol_mag, f"({aug}, {p}, {mag}), ({pol_aug}, {pol_p}, {pol_mag})" + + @augmentation + def yet_another_aug(sample, _): + return sample + + sub_policies = [[(yet_another_aug, 0.5, i), (one.augmentation(mag_range=(0, i)), 0.24, i)] + for i in range(1, 107)] + bigger_policy = Policy(name="BiggerPolicy", num_magnitude_bins=200, sub_policies=sub_policies) + for i, (first, second) in enumerate(bigger_policy.sub_policies): + assert first[0].name == '000__yet_another_aug', f"{second[0].name}" + assert second[0].name == f'{(i + 1):03}__clashing_name', f"{second[0].name}" + + +def test_unused_arg_fail(): + + @experimental.pipeline_def(enable_conditionals=True, batch_size=5, num_threads=4, device_id=0, + seed=43) + def pipeline(): + encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir) + image = fn.decoders.image(encoded_image, device="mixed") + image_net_policy = auto_augment.get_image_net_policy() + return auto_augment.apply_auto_augment(image_net_policy, image, misspelled_kwarg=100) + + msg = "The kwarg `misspelled_kwarg` is not used by any of the augmentations." + with assert_raises(Exception, glob=msg): + pipeline() + + +def test_empty_policy_fail(): + + @experimental.pipeline_def(enable_conditionals=True, batch_size=5, num_threads=4, device_id=0, + seed=43) + def pipeline(): + encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir) + image = fn.decoders.image(encoded_image, device="mixed") + return auto_augment.apply_auto_augment(Policy("ShouldFail", 9, []), image) + + msg = ("Cannot run empty policy. Got Policy(name='ShouldFail', num_magnitude_bins=9, " + "sub_policies=[], augmentations={}) in `apply_auto_augment` call.") + with assert_raises(Exception, glob=msg): + pipeline() + + +def test_missing_shape_fail(): + + @experimental.pipeline_def(enable_conditionals=True, batch_size=5, num_threads=4, device_id=0, + seed=43) + def pipeline(): + encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir) + image = fn.decoders.image(encoded_image, device="mixed") + image_net_policy = auto_augment.get_image_net_policy(use_shape=True) + return auto_augment.apply_auto_augment(image_net_policy, image) + + msg = "`translate_y` * provide it as `shape` argument to `apply_auto_augment` call" + with assert_raises(Exception, glob=msg): + pipeline() + + +def test_wrong_sub_policy_format_fail(): + + with assert_raises(Exception, + glob="The `num_magnitude_bins` must be a positive integer, got 0"): + Policy("ShouldFail", 0.25, a.rotate) + + with assert_raises(Exception, + glob="The `sub_policies` must be a list or tuple of sub policies"): + Policy("ShouldFail", 9, a.rotate) + + with assert_raises(Exception, glob="Each sub policy must be a list or tuple"): + Policy("ShouldFail", 9, [a.rotate]) + + with assert_raises( + Exception, + glob="as a triple: (augmentation, probability, magnitude). Got Augmentation"): + Policy("ShouldFail", 9, [(a.rotate, a.shear_x)]) + + with assert_raises(Exception, glob="must be an instance of Augmentation. Got 0.5"): + Policy("ShouldFail", 9, [[(0.5, a.rotate, 3)]]) + + with assert_raises(Exception, + glob="Probability * must be a number from `[[]0, 1[]]` range. Got 2"): + Policy("ShouldFail", 9, [[(a.rotate, 2, 2)]]) + + with assert_raises(Exception, glob="Magnitude ** `[[]0, 8[]]` range. Got -1"): + Policy("ShouldFail", 9, [[(a.rotate, 1, -1)]])