diff --git a/tests/test_masks.py b/tests/test_masks.py index 2ea7a7f8..4a78d57a 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -188,6 +188,19 @@ def test_shrink_parameter(self): mp.shrink(amount=1, axis=PerturbationAxis.RANDOM) assert pnp.sum(mp.mask) == mp.mask.size - 1 + def test_shrink_entangling(self): + size = 3 + mp = self._create_circuit_with_entangling_gates(size) + mp.entangling_mask[:] = True + mp.shrink(amount=1, axis=PerturbationAxis.ENTANGLING) + assert pnp.sum(mp.entangling_mask) == mp.entangling_mask.size - 1 + + # also test in case no mask is set + mp = self._create_circuit(size) + mp.shrink(amount=1, axis=PerturbationAxis.ENTANGLING) + assert mp.entangling_mask is None + assert pnp.sum(mp.mask) == 0 # also ensure that nothing else was shrunk + def test_shrink_wrong_axis(self): mp = self._create_circuit(3) with pytest.raises(NotImplementedError):