diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index a3dde95fa928f..4ac6b2b4cbb54 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -14,15 +14,9 @@ from pytorch_lightning.utilities.distributed import rank_zero_warn -class WarningCache: - - def __init__(self): - self.warnings = set() +class WarningCache(set): def warn(self, m, *args, **kwargs): - if m not in self.warnings: - self.warnings.add(m) + if m not in self: + self.add(m) rank_zero_warn(m, *args, **kwargs) - - def clear(self): - self.warnings.clear()