Add shortcut to find out if an AbstractArray is differentiable (#784)
* Define AbstractArray.is_differentiable

* Remove warnings.simplefilter that was overriding pytest filterwarnings

* Rename is_differentiable -> requires_grad

HGSilveri authored Dec 20, 2024
1 parent 21c8d46 commit 8db708e
Showing 11 changed files with 36 additions and 50 deletions.
10 changes: 6 additions & 4 deletions pulser-core/pulser/math/abstract_array.py
@@ -71,6 +71,11 @@ def is_tensor(self) -> bool:
         """Whether the stored array is a tensor."""
         return self.has_torch() and isinstance(self._array, torch.Tensor)
 
+    @property
+    def requires_grad(self) -> bool:
+        """Whether the stored array is a tensor that needs a gradient."""
+        return self.is_tensor and cast(torch.Tensor, self._array).requires_grad
+
     def astype(self, dtype: DTypeLike) -> AbstractArray:
         """Casts the data type of the array contents."""
         if self.is_tensor:
@@ -271,10 +276,7 @@ def __setitem__(self, indices: Any, values: AbstractArrayLike) -> None:
                 self._process_indices(indices)
             ] = values  # type: ignore[assignment]
         except RuntimeError as e:
-            if (
-                self.is_tensor
-                and cast(torch.Tensor, self._array).requires_grad
-            ):
+            if self.requires_grad:
                 raise RuntimeError(
                     "Failed to modify a tensor that requires grad in place."
                 ) from e
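
The new property is a shortcut that folds the `is_tensor` check and the tensor cast into a single read-only attribute. A minimal sketch of the intended behavior, assuming `AbstractArray` can be built directly from a torch tensor (the exact public entry point may differ):

    import torch
    import pulser.math as pm

    # Hypothetical construction, for illustration only.
    plain = pm.AbstractArray(torch.zeros(3))
    diff = pm.AbstractArray(torch.zeros(3, requires_grad=True))

    assert not plain.requires_grad  # a tensor, but untracked by autograd
    assert diff.requires_grad
    # Equivalent to the old, longer spelling:
    assert diff.is_tensor and diff.as_tensor().requires_grad
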
1 change: 0 additions & 1 deletion pulser-core/pulser/register/base_register.py
@@ -79,7 +79,6 @@ def __init__(
         )
         self._ids: tuple[QubitId, ...] = tuple(qubits.keys())
         if any(not isinstance(id, str) for id in self._ids):
-            warnings.simplefilter("always")
             warnings.warn(
                 "Usage of `int`s or any non-`str`types as `QubitId`s will be "
                 "deprecated. Define your `QubitId`s as `str`s, prefer setting "
6 changes: 2 additions & 4 deletions tests/test_channels.py
@@ -292,8 +292,7 @@ def test_modulation(channel, tr, eom, side_buffer_len, requires_grad):
         tr,
         tr,
     )
-    if requires_grad:
-        assert out_.as_tensor().requires_grad
+    assert out_.requires_grad == requires_grad
 
     wf2 = BlackmanWaveform(800, wf_vals[1])
     out_ = channel.modulate(wf2.samples, eom=eom)
@@ -302,8 +301,7 @@ def test_modulation(channel, tr, eom, side_buffer_len, requires_grad):
         side_buffer_len,
         side_buffer_len,
     )
-    if requires_grad:
-        assert out_.as_tensor().requires_grad
+    assert out_.requires_grad == requires_grad
 
 
 @pytest.mark.parametrize(
6 changes: 2 additions & 4 deletions tests/test_eom.py
@@ -190,8 +190,7 @@ def calc_offset(amp):
         ]
     )
     assert calculated_det_off == min(det_off_options, key=abs)
-    if requires_grad:
-        assert calculated_det_off.as_tensor().requires_grad
+    assert calculated_det_off.requires_grad == requires_grad
 
     # Case where the EOM pulses are off-resonant
     detuning_on = detuning_on + 1.0
@@ -210,5 +209,4 @@ def calc_offset(amp):
     assert off_options[0] == eom_.calculate_detuning_off(
         amp, detuning_on, optimal_detuning_off=0.0
     )
-    if requires_grad:
-        assert off_options.as_tensor().requires_grad
+    assert off_options.requires_grad == requires_grad
10 changes: 4 additions & 6 deletions tests/test_math.py
@@ -39,8 +39,7 @@ def test_pad(cast_to, requires_grad):
         arr = torch.tensor(arr, requires_grad=requires_grad)
 
     def check_match(arr1: pm.AbstractArray, arr2):
-        if requires_grad:
-            assert arr1.as_tensor().requires_grad
+        assert arr1.requires_grad == requires_grad
         np.testing.assert_array_equal(
             arr1.as_array(detach=requires_grad), arr2
         )
@@ -260,8 +259,7 @@ def test_items(self, use_tensor, requires_grad, indices):
             assert item == val[i]
             assert isinstance(item, pm.AbstractArray)
             assert item.is_tensor == use_tensor
-            if use_tensor:
-                assert item.as_tensor().requires_grad == requires_grad
+            assert item.requires_grad == requires_grad
 
         # setitem
         if not requires_grad:
@@ -292,8 +290,8 @@ def test_items(self, use_tensor, requires_grad, indices):
         new_val[indices] = 0.0
         assert np.all(arr_np == new_val)
         assert arr_np.is_tensor
-        # The resulting tensor requires grad if the assing one did
-        assert arr_np.as_tensor().requires_grad == requires_grad
+        # The resulting tensor requires grad if the assigned one did
+        assert arr_np.requires_grad == requires_grad
 
     @pytest.mark.parametrize("scalar", [False, True])
     @pytest.mark.parametrize(
7 changes: 2 additions & 5 deletions tests/test_parametrized.py
@@ -104,10 +104,7 @@ def test_var_diff(a, b, requires_grad):
     b._assign(torch.tensor([-1.0, 1.0], requires_grad=requires_grad))
 
     for var in [a, b]:
-        assert (
-            a.value is not None
-            and a.value.as_tensor().requires_grad == requires_grad
-        )
+        assert a.value is not None and a.value.requires_grad == requires_grad
 
 
 def test_varitem(a, b, d):
@@ -167,7 +164,7 @@ def test_paramobj(bwf, t, a, b):
 def test_opsupport(a, b, with_diff_tensor):
     def check_var_grad(var):
         if with_diff_tensor:
-            assert var.build().as_tensor().requires_grad
+            assert var.build().requires_grad
 
     a._assign(-2.0)
     if with_diff_tensor:
6 changes: 3 additions & 3 deletions tests/test_pulse.py
@@ -234,9 +234,9 @@ def test_eq():
 
 
 def _assert_pulse_requires_grad(pulse: Pulse, invert: bool = False) -> None:
-    assert pulse.amplitude.samples.as_tensor().requires_grad == (not invert)
-    assert pulse.detuning.samples.as_tensor().requires_grad == (not invert)
-    assert pulse.phase.as_tensor().requires_grad == (not invert)
+    assert pulse.amplitude.samples.requires_grad == (not invert)
+    assert pulse.detuning.samples.requires_grad == (not invert)
+    assert pulse.phase.requires_grad == (not invert)
 
 
 @pytest.mark.parametrize("requires_grad", [True, False])
4 changes: 2 additions & 2 deletions tests/test_register.py
@@ -508,9 +508,9 @@ def _assert_reg_requires_grad(
 ) -> None:
     for coords in reg.qubits.values():
         if invert:
-            assert not coords.as_tensor().requires_grad
+            assert not coords.requires_grad
         else:
-            assert coords.is_tensor and coords.as_tensor().requires_grad
+            assert coords.is_tensor and coords.requires_grad
 
 
 @pytest.mark.parametrize(
12 changes: 6 additions & 6 deletions tests/test_sequence.py
@@ -2886,12 +2886,12 @@ def test_sequence_diff(device, parametrized, with_modulation, with_eom):
 
     seq_samples = sample(seq, modulation=with_modulation)
     ryd_ch_samples = seq_samples.channel_samples["ryd_global"]
-    assert ryd_ch_samples.amp.as_tensor().requires_grad
-    assert ryd_ch_samples.det.as_tensor().requires_grad
-    assert ryd_ch_samples.phase.as_tensor().requires_grad
+    assert ryd_ch_samples.amp.requires_grad
+    assert ryd_ch_samples.det.requires_grad
+    assert ryd_ch_samples.phase.requires_grad
     if "dmm_0" in seq_samples.channel_samples:
         dmm_ch_samples = seq_samples.channel_samples["dmm_0"]
         # Only detuning is modulated
-        assert not dmm_ch_samples.amp.as_tensor().requires_grad
-        assert dmm_ch_samples.det.as_tensor().requires_grad
-        assert not dmm_ch_samples.phase.as_tensor().requires_grad
+        assert not dmm_ch_samples.amp.requires_grad
+        assert dmm_ch_samples.det.requires_grad
+        assert not dmm_ch_samples.phase.requires_grad
10 changes: 5 additions & 5 deletions tests/test_sequence_sampler.py
@@ -523,11 +523,11 @@ def test_phase_modulation(off_center, with_diff):
     seq_samples = sample(seq).channel_samples["rydberg_global"]
 
     if with_diff:
-        assert full_phase.samples.as_tensor().requires_grad
-        assert not seq_samples.amp.as_tensor().requires_grad
-        assert seq_samples.det.as_tensor().requires_grad
-        assert seq_samples.phase.as_tensor().requires_grad
-        assert seq_samples.phase_modulation.as_tensor().requires_grad
+        assert full_phase.samples.requires_grad
+        assert not seq_samples.amp.requires_grad
+        assert seq_samples.det.requires_grad
+        assert seq_samples.phase.requires_grad
+        assert seq_samples.phase_modulation.requires_grad
 
     np.testing.assert_allclose(
         seq_samples.phase_modulation.as_array(detach=with_diff)
14 changes: 4 additions & 10 deletions tests/test_waveforms.py
@@ -490,26 +490,20 @@ def test_waveform_diff(
 
     samples_tensor = wf.samples.as_tensor()
     assert samples_tensor.requires_grad == requires_grad
-    assert (
-        wf.modulated_samples(rydberg_global).as_tensor().requires_grad
-        == requires_grad
-    )
+    assert wf.modulated_samples(rydberg_global).requires_grad == requires_grad
     wfx2_tensor = (-wf * 2).samples.as_tensor()
     assert torch.equal(wfx2_tensor, samples_tensor * -2.0)
    assert wfx2_tensor.requires_grad == requires_grad
 
     wfdiv2 = wf / torch.tensor(2.0, requires_grad=True)
     assert torch.equal(wfdiv2.samples.as_tensor(), samples_tensor / 2.0)
     # Should always be true because it was divided by diff tensor
-    assert wfdiv2.samples.as_tensor().requires_grad
+    assert wfdiv2.samples.requires_grad
 
-    assert wf[-1].as_tensor().requires_grad == requires_grad
+    assert wf[-1].requires_grad == requires_grad
 
     try:
-        assert (
-            wf.change_duration(1000).samples.as_tensor().requires_grad
-            == requires_grad
-        )
+        assert wf.change_duration(1000).samples.requires_grad == requires_grad
     except NotImplementedError:
         pass
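
The now-unconditional assertion on `wfdiv2` rests on a standard autograd rule: any operation involving a tensor with `requires_grad=True` produces a result that also requires grad. A self-contained sketch of that rule:

    import torch

    a = torch.tensor([1.0, 2.0])               # requires_grad=False
    b = torch.tensor(2.0, requires_grad=True)
    out = a / b
    assert out.requires_grad  # grad tracking propagates through the division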

