diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d197a1d..c076ab5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: - name: Run Flake8 run: flake8 - name: Run mypy - run: mypy . + run: mypy autoalbument check_code_formatting: name: Check code formatting with Black diff --git a/autoalbument/__init__.py b/autoalbument/__init__.py index 493f741..260c070 100644 --- a/autoalbument/__init__.py +++ b/autoalbument/__init__.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.3.1" diff --git a/autoalbument/faster_autoaugment/policy.py b/autoalbument/faster_autoaugment/policy.py index 0a132af..70c8b75 100644 --- a/autoalbument/faster_autoaugment/policy.py +++ b/autoalbument/faster_autoaugment/policy.py @@ -4,6 +4,7 @@ """ import random +import warnings from copy import deepcopy from typing import Dict @@ -30,6 +31,9 @@ from autoalbument.faster_autoaugment.utils import MAX_VALUES_BY_INPUT_DTYPE, target_requires_grad +PROBABILITY_EPS = 0.01 + + class SubPolicyStage(nn.Module): def __init__( self, @@ -60,14 +64,20 @@ def create_transform(self, input_dtype): weights = self.weights.detach().cpu().numpy().tolist() probabilities = [op.probability.item() for op in self.operations] true_probabilities = [w * p for (w, p) in zip(weights, probabilities)] - assert sum(true_probabilities) <= 1.0 + p_sum = sum(true_probabilities) + if p_sum > 1.0 + PROBABILITY_EPS: + warnings.warn( + f"Sum of all augmentation probabilities exceeds 1.0 and equals {p_sum}. " + "This may indicate an error in AutoAlbument. " + "Please report an issue at https://github.com/albumentations-team/autoalbument/issues.", + RuntimeWarning, + ) transforms = [] - p_sum = 0 for operation, p in zip(self.operations, true_probabilities): transforms.append(operation.create_transform(input_dtype, p)) - p_sum += p - transforms.append(A.NoOp(p=1.0 - p_sum)) - return OneOf(transforms, p=1) + if p_sum < 1.0: + transforms.append(A.NoOp(p=1.0 - p_sum)) + return OneOf(transforms, p=1.0) class SubPolicy(nn.Module):