Skip to content

Commit

Permalink
Remove probability assert (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
creafz committed Feb 16, 2021
1 parent e14cdd8 commit 357838d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- name: Run Flake8
run: flake8
- name: Run mypy
run: mypy .
run: mypy autoalbument

check_code_formatting:
name: Check code formatting with Black
Expand Down
2 changes: 1 addition & 1 deletion autoalbument/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.3.1"
20 changes: 15 additions & 5 deletions autoalbument/faster_autoaugment/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import random
import warnings
from copy import deepcopy
from typing import Dict

Expand All @@ -30,6 +31,9 @@
from autoalbument.faster_autoaugment.utils import MAX_VALUES_BY_INPUT_DTYPE, target_requires_grad


PROBABILITY_EPS = 0.01


class SubPolicyStage(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -60,14 +64,20 @@ def create_transform(self, input_dtype):
weights = self.weights.detach().cpu().numpy().tolist()
probabilities = [op.probability.item() for op in self.operations]
true_probabilities = [w * p for (w, p) in zip(weights, probabilities)]
assert sum(true_probabilities) <= 1.0
p_sum = sum(true_probabilities)
if p_sum > 1.0 + PROBABILITY_EPS:
warnings.warn(
f"Sum of all augmentation probabilities exceeds 1.0 and equals {p_sum}. "
"This may indicate an error in AutoAlbument. "
"Please report an issue at https://github.com/albumentations-team/autoalbument/issues.",
RuntimeWarning,
)
transforms = []
p_sum = 0
for operation, p in zip(self.operations, true_probabilities):
transforms.append(operation.create_transform(input_dtype, p))
p_sum += p
transforms.append(A.NoOp(p=1.0 - p_sum))
return OneOf(transforms, p=1)
if p_sum < 1.0:
transforms.append(A.NoOp(p=1.0 - p_sum))
return OneOf(transforms, p=1.0)


class SubPolicy(nn.Module):
Expand Down

0 comments on commit 357838d

Please sign in to comment.