diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3f915022be24c..dcc819ff9d04d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `PyTorchProfiler` ([#5560](https://github.com/PyTorchLightning/pytorch-lightning/pull/5560))
 
+- Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464))
+
+
 ### Changed
 
 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 12f26b5726f9d..870ef819d45b6 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -258,6 +258,51 @@ In practise this means that:
     val = metric(pred, target) # this value can be backpropagated
     val = metric.compute() # this value cannot be backpropagated
 
+******************
+Metric Arithmetic
+******************
+
+Metrics support most of Python's built-in operators for arithmetic, logic and bitwise operations.
+
+For example, to obtain a metric that returns the sum of two different metrics, there is no need to implement a new metric class.
+It can simply be done with:
+
+.. code-block:: python
+
+    first_metric = MyFirstMetric()
+    second_metric = MySecondMetric()
+
+    new_metric = first_metric + second_metric
+
+``new_metric.update(*args, **kwargs)`` now calls ``update`` of ``first_metric`` and ``second_metric``. It forwards all positional arguments, but
+forwards only those keyword arguments that appear in the respective metric's ``update`` signature.
+
+Similarly, ``new_metric.compute()`` calls ``compute`` of ``first_metric`` and ``second_metric`` and adds up the results.
+
+This pattern is implemented for the following operators (with ``a`` being a metric and ``b`` being a metric, tensor, integer or float):
+
+* Addition (``a + b``)
+* Bitwise AND (``a & b``)
+* Equality (``a == b``)
+* Floor Division (``a // b``)
+* Greater Equal (``a >= b``)
+* Greater (``a > b``)
+* Less Equal (``a <= b``)
+* Less (``a < b``)
+* Matrix Multiplication (``a @ b``)
+* Modulo (``a % b``)
+* Multiplication (``a * b``)
+* Inequality (``a != b``)
+* Bitwise OR (``a | b``)
+* Power (``a ** b``)
+* Subtraction (``a - b``)
+* True Division (``a / b``)
+* Bitwise XOR (``a ^ b``)
+* Absolute Value (``abs(a)``)
+* Inversion (``~a``)
+* Negative Value (``neg(a)``)
+* Positive Value (``pos(a)``)
+
 ****************
 MetricCollection
 ****************
diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py
new file mode 100644
index 0000000000000..df98d16a3ef7e
--- /dev/null
+++ b/pytorch_lightning/metrics/compositional.py
@@ -0,0 +1,92 @@
+from typing import Callable, Union
+
+import torch
+
+from pytorch_lightning.metrics.metric import Metric
+
+
+class CompositionalMetric(Metric):
+    """Composition of two metrics with a specific operator,
+    which will be executed upon the metric's compute
+
+    """
+
+    def __init__(
+        self,
+        operator: Callable,
+        metric_a: Union[Metric, int, float, torch.Tensor],
+        metric_b: Union[Metric, int, float, torch.Tensor, None],
+    ):
+        """
+
+        Args:
+            operator: the operator taking in one (if metric_b is None)
+                or two arguments.
Will be applied to outputs of metric_a.compute() + and (optionally if metric_b is not None) metric_b.compute() + metric_a: first metric whose compute() result is the first argument of operator + metric_b: second metric whose compute() result is the second argument of operator. + For operators taking in only one input, this should be None + """ + super().__init__() + + self.op = operator + + if isinstance(metric_a, torch.Tensor): + self.register_buffer("metric_a", metric_a) + else: + self.metric_a = metric_a + + if isinstance(metric_b, torch.Tensor): + self.register_buffer("metric_b", metric_b) + else: + self.metric_b = metric_b + + def _sync_dist(self, dist_sync_fn=None): + # No syncing required here. syncing will be done in metric_a and metric_b + pass + + def update(self, *args, **kwargs): + if isinstance(self.metric_a, Metric): + self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) + + if isinstance(self.metric_b, Metric): + self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) + + def compute(self): + + # also some parsing for kwargs? + if isinstance(self.metric_a, Metric): + val_a = self.metric_a.compute() + else: + val_a = self.metric_a + + if isinstance(self.metric_b, Metric): + val_b = self.metric_b.compute() + else: + val_b = self.metric_b + + if val_b is None: + return self.op(val_a) + + return self.op(val_a, val_b) + + def reset(self): + if isinstance(self.metric_a, Metric): + self.metric_a.reset() + + if isinstance(self.metric_b, Metric): + self.metric_b.reset() + + def persistent(self, mode: bool = False): + if isinstance(self.metric_a, Metric): + self.metric_a.persistent(mode=mode) + if isinstance(self.metric_b, Metric): + self.metric_b.persistent(mode=mode) + + def __repr__(self): + repr_str = ( + self.__class__.__name__ + + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)" + ) + + return repr_str diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 5c8aaefc2084a..794e696e98f8a 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -306,6 +306,200 @@ def state_dict(self, *args, **kwargs): state_dict.update({key: current_val}) return state_dict + def _filter_kwargs(self, **kwargs): + """ filter kwargs such that they match the update signature of the metric """ + + # filter all parameters based on update signature except those of + # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) + _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + filtered_kwargs = { + k: v + for k, v in kwargs.items() + if k in self._update_signature.parameters.keys() + and self._update_signature.parameters[k].kind not in _params + } + + # if no kwargs filtered, return al kwargs as default + if not filtered_kwargs: + filtered_kwargs = kwargs + return filtered_kwargs + + def __hash__(self): + hash_vals = [self.__class__.__name__] + + for key in self._defaults.keys(): + hash_vals.append(getattr(self, key)) + + return hash(tuple(hash_vals)) + + def __add__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.add, self, other) + + def __and__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_and, self, other) + + def __eq__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.eq, self, other) + + 
def __floordiv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.floor_divide, self, other) + + def __ge__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.ge, self, other) + + def __gt__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.gt, self, other) + + def __le__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.le, self, other) + + def __lt__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.lt, self, other) + + def __matmul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.matmul, self, other) + + def __mod__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.fmod, self, other) + + def __mul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.mul, self, other) + + def __ne__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.ne, self, other) + + def __or__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_or, self, other) + + def __pow__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.pow, self, other) + + def __radd__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.add, other, self) + + def __rand__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + # swap them since bitwise_and only supports that way and it's commutative + return CompositionalMetric(torch.bitwise_and, self, other) + + def __rfloordiv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.floor_divide, other, self) + + def __rmatmul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.matmul, other, self) + + def __rmod__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.fmod, other, self) + + def __rmul__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.mul, other, self) + + def __ror__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_or, other, self) + + def __rpow__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.pow, other, self) + + def __rsub__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.sub, other, self) + + def __rtruediv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return 
CompositionalMetric(torch.true_divide, other, self) + + def __rxor__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_xor, other, self) + + def __sub__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.sub, self, other) + + def __truediv__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.true_divide, self, other) + + def __xor__(self, other: Any): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_xor, self, other) + + def __abs__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.abs, self, None) + + def __inv__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.bitwise_not, self, None) + + def __invert__(self): + return self.__inv__() + + def __neg__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(_neg, self, None) + + def __pos__(self): + from pytorch_lightning.metrics.compositional import CompositionalMetric + + return CompositionalMetric(torch.abs, self, None) + + +def _neg(tensor: torch.Tensor): + return -torch.abs(tensor) + class MetricCollection(nn.ModuleDict): """ @@ -342,30 +536,29 @@ class MetricCollection(nn.ModuleDict): {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} """ + def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): super().__init__() if isinstance(metrics, dict): # Check all values are metrics for name, metric in metrics.items(): if not isinstance(metric, Metric): - raise ValueError(f'Value {metric} belonging to key {name}' - ' is not an instance of `pl.metrics.Metric`') + raise ValueError( + f"Value {metric} belonging to key {name}" " is not an instance of `pl.metrics.Metric`" + ) self[name] = metric elif isinstance(metrics, (tuple, list)): for metric in metrics: if not isinstance(metric, Metric): - raise ValueError(f'Input {metric} to `MetricCollection` is not a instance' - ' of `pl.metrics.Metric`') + raise ValueError( + f"Input {metric} to `MetricCollection` is not a instance" " of `pl.metrics.Metric`" + ) name = metric.__class__.__name__ if name in self: - raise ValueError(f'Encountered two metrics both named {name}') + raise ValueError(f"Encountered two metrics both named {name}") self[name] = metric else: - raise ValueError('Unknown input to MetricCollection.') - - def _filter_kwargs(self, metric: Metric, **kwargs): - """ filter kwargs such that they match the update signature of the metric """ - return {k: v for k, v in kwargs.items() if k in metric._update_signature.parameters.keys()} + raise ValueError("Unknown input to MetricCollection.") def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 """ @@ -373,7 +566,7 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 be passed to every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. 
""" - return {k: m(*args, **self._filter_kwargs(m, **kwargs)) for k, m in self.items()} + return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} def update(self, *args, **kwargs): # pylint: disable=E0202 """ @@ -382,7 +575,7 @@ def update(self, *args, **kwargs): # pylint: disable=E0202 will be filtered based on the signature of the individual metric. """ for _, m in self.items(): - m_kwargs = self._filter_kwargs(m, **kwargs) + m_kwargs = m._filter_kwargs(**kwargs) m.update(*args, **m_kwargs) def compute(self) -> Dict[str, Any]: @@ -398,8 +591,8 @@ def clone(self): return deepcopy(self) def persistent(self, mode: bool = True): - """ Method for post-init to change if metric states should be saved to - its state_dict + """Method for post-init to change if metric states should be saved to + its state_dict """ for _, m in self.items(): m.persistent(mode) diff --git a/tests/metrics/test_composition.py b/tests/metrics/test_composition.py new file mode 100644 index 0000000000000..087dee521d694 --- /dev/null +++ b/tests/metrics/test_composition.py @@ -0,0 +1,485 @@ +from operator import neg, pos + +import pytest +import torch + +from pytorch_lightning.metrics.compositional import CompositionalMetric +from pytorch_lightning.metrics.metric import Metric + + +class DummyMetric(Metric): + def __init__(self, val_to_return): + super().__init__() + self._num_updates = 0 + self._val_to_return = val_to_return + + def update(self, *args, **kwargs) -> None: + self._num_updates += 1 + + def compute(self): + return torch.tensor(self._val_to_return) + + def reset(self): + self._num_updates = 0 + return super().reset() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(4)), + (2, torch.tensor(4)), + (2.0, torch.tensor(4.0)), + (torch.tensor(2), torch.tensor(4)), + ], +) +def test_metrics_add(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_add = first_metric + second_operand + final_radd = second_operand + first_metric + + assert isinstance(final_add, CompositionalMetric) + assert isinstance(final_radd, CompositionalMetric) + + assert torch.allclose(expected_result, final_add.compute()) + assert torch.allclose(expected_result, final_radd.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))], +) +def test_metrics_and(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_and = first_metric & second_operand + final_rand = second_operand & first_metric + + assert isinstance(final_and, CompositionalMetric) + assert isinstance(final_rand, CompositionalMetric) + + assert torch.allclose(expected_result, final_and.compute()) + assert torch.allclose(expected_result, final_rand.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(True)), + (2, torch.tensor(True)), + (2.0, torch.tensor(True)), + (torch.tensor(2), torch.tensor(True)), + ], +) +def test_metrics_eq(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_eq = first_metric == second_operand + + assert isinstance(final_eq, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_eq.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(2)), + (2, torch.tensor(2)), + (2.0, torch.tensor(2.0)), + 
(torch.tensor(2), torch.tensor(2)), + ], +) +def test_metrics_floordiv(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_floordiv = first_metric // second_operand + + assert isinstance(final_floordiv, CompositionalMetric) + + assert torch.allclose(expected_result, final_floordiv.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(True)), + (2, torch.tensor(True)), + (2.0, torch.tensor(True)), + (torch.tensor(2), torch.tensor(True)), + ], +) +def test_metrics_ge(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_ge = first_metric >= second_operand + + assert isinstance(final_ge, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_ge.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(True)), + (2, torch.tensor(True)), + (2.0, torch.tensor(True)), + (torch.tensor(2), torch.tensor(True)), + ], +) +def test_metrics_gt(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_gt = first_metric > second_operand + + assert isinstance(final_gt, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_gt.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(False)), + (2, torch.tensor(False)), + (2.0, torch.tensor(False)), + (torch.tensor(2), torch.tensor(False)), + ], +) +def test_metrics_le(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_le = first_metric <= second_operand + + assert isinstance(final_le, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_le.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(False)), + (2, torch.tensor(False)), + (2.0, torch.tensor(False)), + (torch.tensor(2), torch.tensor(False)), + ], +) +def test_metrics_lt(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_lt = first_metric < second_operand + + assert isinstance(final_lt, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_lt.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))], +) +def test_metrics_matmul(second_operand, expected_result): + first_metric = DummyMetric([2, 2, 2]) + + final_matmul = first_metric @ second_operand + + assert isinstance(final_matmul, CompositionalMetric) + + assert torch.allclose(expected_result, final_matmul.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(1)), + (2, torch.tensor(1)), + (2.0, torch.tensor(1)), + (torch.tensor(2), torch.tensor(1)), + ], +) +def test_metrics_mod(second_operand, expected_result): + first_metric = DummyMetric(5) + + final_mod = first_metric % second_operand + + assert isinstance(final_mod, CompositionalMetric) + + assert torch.allclose(expected_result, final_mod.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(4)), + (2, torch.tensor(4)), + (2.0, torch.tensor(4.0)), + (torch.tensor(2), torch.tensor(4)), + ], +) +def test_metrics_mul(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_mul = 
first_metric * second_operand + final_rmul = second_operand * first_metric + + assert isinstance(final_mul, CompositionalMetric) + assert isinstance(final_rmul, CompositionalMetric) + + assert torch.allclose(expected_result, final_mul.compute()) + assert torch.allclose(expected_result, final_rmul.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(False)), + (2, torch.tensor(False)), + (2.0, torch.tensor(False)), + (torch.tensor(2), torch.tensor(False)), + ], +) +def test_metrics_ne(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_ne = first_metric != second_operand + + assert isinstance(final_ne, CompositionalMetric) + + # can't use allclose for bool tensors + assert (expected_result == final_ne.compute()).all() + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))], +) +def test_metrics_or(second_operand, expected_result): + first_metric = DummyMetric([-1, -2, 3]) + + final_or = first_metric | second_operand + final_ror = second_operand | first_metric + + assert isinstance(final_or, CompositionalMetric) + assert isinstance(final_ror, CompositionalMetric) + + assert torch.allclose(expected_result, final_or.compute()) + assert torch.allclose(expected_result, final_ror.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(4)), + (2, torch.tensor(4)), + (2.0, torch.tensor(4.0)), + (torch.tensor(2), torch.tensor(4)), + ], +) +def test_metrics_pow(second_operand, expected_result): + first_metric = DummyMetric(2) + + final_pow = first_metric ** second_operand + + assert isinstance(final_pow, CompositionalMetric) + + assert torch.allclose(expected_result, final_pow.compute()) + + +@pytest.mark.parametrize( + ["first_operand", "expected_result"], + [(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))], +) +def test_metrics_rfloordiv(first_operand, expected_result): + second_operand = DummyMetric(2) + + final_rfloordiv = first_operand // second_operand + + assert isinstance(final_rfloordiv, CompositionalMetric) + assert torch.allclose(expected_result, final_rfloordiv.compute()) + + +@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor([2, 2, 2]), torch.tensor(12))]) +def test_metrics_rmatmul(first_operand, expected_result): + second_operand = DummyMetric([2, 2, 2]) + + final_rmatmul = first_operand @ second_operand + + assert isinstance(final_rmatmul, CompositionalMetric) + + assert torch.allclose(expected_result, final_rmatmul.compute()) + + +@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor(2), torch.tensor(2))]) +def test_metrics_rmod(first_operand, expected_result): + second_operand = DummyMetric(5) + + final_rmod = first_operand % second_operand + + assert isinstance(final_rmod, CompositionalMetric) + + assert torch.allclose(expected_result, final_rmod.compute()) + + +@pytest.mark.parametrize( + ["first_operand", "expected_result"], + [(DummyMetric(2), torch.tensor(4)), (2, torch.tensor(4)), (2.0, torch.tensor(4.0))], +) +def test_metrics_rpow(first_operand, expected_result): + second_operand = DummyMetric(2) + + final_rpow = first_operand ** second_operand + + assert isinstance(final_rpow, CompositionalMetric) + + assert torch.allclose(expected_result, final_rpow.compute()) + + +@pytest.mark.parametrize( + ["first_operand", 
"expected_result"], + [ + (DummyMetric(3), torch.tensor(1)), + (3, torch.tensor(1)), + (3.0, torch.tensor(1.0)), + (torch.tensor(3), torch.tensor(1)), + ], +) +def test_metrics_rsub(first_operand, expected_result): + second_operand = DummyMetric(2) + + final_rsub = first_operand - second_operand + + assert isinstance(final_rsub, CompositionalMetric) + + assert torch.allclose(expected_result, final_rsub.compute()) + + +@pytest.mark.parametrize( + ["first_operand", "expected_result"], + [ + (DummyMetric(6), torch.tensor(2.0)), + (6, torch.tensor(2.0)), + (6.0, torch.tensor(2.0)), + (torch.tensor(6), torch.tensor(2.0)), + ], +) +def test_metrics_rtruediv(first_operand, expected_result): + second_operand = DummyMetric(3) + + final_rtruediv = first_operand / second_operand + + assert isinstance(final_rtruediv, CompositionalMetric) + + assert torch.allclose(expected_result, final_rtruediv.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(2), torch.tensor(1)), + (2, torch.tensor(1)), + (2.0, torch.tensor(1.0)), + (torch.tensor(2), torch.tensor(1)), + ], +) +def test_metrics_sub(second_operand, expected_result): + first_metric = DummyMetric(3) + + final_sub = first_metric - second_operand + + assert isinstance(final_sub, CompositionalMetric) + + assert torch.allclose(expected_result, final_sub.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [ + (DummyMetric(3), torch.tensor(2.0)), + (3, torch.tensor(2.0)), + (3.0, torch.tensor(2.0)), + (torch.tensor(3), torch.tensor(2.0)), + ], +) +def test_metrics_truediv(second_operand, expected_result): + first_metric = DummyMetric(6) + + final_truediv = first_metric / second_operand + + assert isinstance(final_truediv, CompositionalMetric) + + assert torch.allclose(expected_result, final_truediv.compute()) + + +@pytest.mark.parametrize( + ["second_operand", "expected_result"], + [(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))], +) +def test_metrics_xor(second_operand, expected_result): + first_metric = DummyMetric([-1, -2, 3]) + + final_xor = first_metric ^ second_operand + final_rxor = second_operand ^ first_metric + + assert isinstance(final_xor, CompositionalMetric) + assert isinstance(final_rxor, CompositionalMetric) + + assert torch.allclose(expected_result, final_xor.compute()) + assert torch.allclose(expected_result, final_rxor.compute()) + + +def test_metrics_abs(): + first_metric = DummyMetric(-1) + + final_abs = abs(first_metric) + + assert isinstance(final_abs, CompositionalMetric) + + assert torch.allclose(torch.tensor(1), final_abs.compute()) + + +def test_metrics_invert(): + first_metric = DummyMetric(1) + + final_inverse = ~first_metric + assert isinstance(final_inverse, CompositionalMetric) + assert torch.allclose(torch.tensor(-2), final_inverse.compute()) + + +def test_metrics_neg(): + first_metric = DummyMetric(1) + + final_neg = neg(first_metric) + assert isinstance(final_neg, CompositionalMetric) + assert torch.allclose(torch.tensor(-1), final_neg.compute()) + + +def test_metrics_pos(): + first_metric = DummyMetric(-1) + + final_pos = pos(first_metric) + assert isinstance(final_pos, CompositionalMetric) + assert torch.allclose(torch.tensor(1), final_pos.compute()) + + +def test_compositional_metrics_update(): + + compos = DummyMetric(5) + DummyMetric(4) + + assert isinstance(compos, CompositionalMetric) + compos.update() + compos.update() + compos.update() + + assert isinstance(compos.metric_a, 
DummyMetric) + assert isinstance(compos.metric_b, DummyMetric) + + assert compos.metric_a._num_updates == 3 + assert compos.metric_b._num_updates == 3
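+
+
+def test_nested_compositional_metrics():
+    # Illustrative sketch, not part of the original change set: because a
+    # ``CompositionalMetric`` is itself a ``Metric``, it inherits the operator
+    # overloads, so an already composed metric can be composed again.
+    nested = (DummyMetric(2) + DummyMetric(3)) * 2
+
+    assert isinstance(nested, CompositionalMetric)
+    assert torch.allclose(torch.tensor(10), nested.compute())
+
+
+def test_compositional_metrics_kwargs_filtering():
+    # Illustrative sketch, not part of the original change set: ``update`` on a
+    # composed metric should forward to each sub-metric only the keyword
+    # arguments declared in that sub-metric's own ``update`` signature
+    # (via ``Metric._filter_kwargs``). The helper classes below are hypothetical.
+    class FooMetric(DummyMetric):
+        def update(self, foo):
+            self.last_kwargs = {"foo": foo}
+
+    class BarMetric(DummyMetric):
+        def update(self, bar):
+            self.last_kwargs = {"bar": bar}
+
+    composed = FooMetric(0) + BarMetric(0)
+
+    assert isinstance(composed, CompositionalMetric)
+
+    composed.update(foo="x", bar="y")
+
+    assert composed.metric_a.last_kwargs == {"foo": "x"}
+    assert composed.metric_b.last_kwargs == {"bar": "y"}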