Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support differentiability through Torch tensors #703

Merged
merged 53 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
0803ace
Defining pulser.math and AbstractArray
HGSilveri May 7, 2024
f6a701a
POC: Differentiable constant pulse amp
HGSilveri May 7, 2024
3c10cfb
Fix typing in waveforms
HGSilveri May 10, 2024
0941ffa
Fix all typing errors in POC
HGSilveri May 10, 2024
c4428a8
Pass all existing UTs
HGSilveri May 10, 2024
fb44b7e
Pass all UTs without array support
HGSilveri May 13, 2024
c1f5771
Fix typing
HGSilveri May 13, 2024
da4d048
All tests pass with torch installed
HGSilveri May 13, 2024
44c58dd
Add support for pulser-diff backend (#686)
vytautas-a May 29, 2024
7b09ea6
Fix UTs after pulser-diff changes
HGSilveri May 30, 2024
6b5823b
Avoid using AbstractArrayLike outside of pulser.math
HGSilveri May 30, 2024
2b4db0a
Preserve gradient in EOM mode
HGSilveri Jun 3, 2024
10d52df
Add torch as an optional requirement
HGSilveri Jun 3, 2024
7d1d1ed
Support waveform multiplication with abstract array
HGSilveri Jun 3, 2024
b56b2bb
Explicitly marking the differentiable parameters
HGSilveri Jun 4, 2024
5d9ce35
Remove __array_wrap__
HGSilveri Jun 4, 2024
d704e31
Pass relevant UTs without array support
HGSilveri Jun 13, 2024
b3eed41
Support new features
HGSilveri Jun 13, 2024
283f840
Using pm.Differentiable whenever possible
HGSilveri Jun 13, 2024
a4bf66b
Simplifying Waveform.__getitem__() type hint
HGSilveri Jun 17, 2024
4d7af5d
UTs for new features outside of pulser.math
HGSilveri Jun 18, 2024
6a0f3be
Write torch UTs for registers
HGSilveri Jun 24, 2024
3a8085a
Write UTs for waveforms
HGSilveri Jun 24, 2024
0a73c80
UTs for pulse
HGSilveri Jun 25, 2024
95a63e2
UTs for EOM
HGSilveri Jun 25, 2024
9525725
UTs on internal functionality
HGSilveri Jun 25, 2024
46c25a8
UTs for Sequence with autograd
HGSilveri Jun 26, 2024
fd057bc
Implicitly cover math functions
HGSilveri Jun 28, 2024
5a33903
Removing AbstractArray.__hash__() and differentiable phase shifts
HGSilveri Jun 28, 2024
93ab7ec
Finish unit tests
HGSilveri Jul 1, 2024
c20a61b
Update CI to run tests with and without torch
HGSilveri Jul 1, 2024
6c220a1
Fix CI errors
HGSilveri Jul 1, 2024
a724628
Fix failing no-torch UT
HGSilveri Jul 1, 2024
70a818d
Minor corrections
HGSilveri Jul 2, 2024
97f4045
Include pulser[torch] installation in the README
HGSilveri Jul 2, 2024
7af901d
Merge branch 'develop' into abstract-array
HGSilveri Jul 2, 2024
64436e4
Fix warning in UT after merge
HGSilveri Jul 2, 2024
03838d5
Merge branch 'develop' into abstract-array
HGSilveri Jul 2, 2024
8fe6079
Merge branch 'develop' into abstract-array
HGSilveri Jul 18, 2024
a04c600
Merge branch 'develop' into abstract-array
HGSilveri Jul 19, 2024
5f071cc
Merge branch 'develop' into abstract-array
HGSilveri Jul 23, 2024
4ad08e2
Incorporating the latest changes
HGSilveri Jul 23, 2024
d0992b5
Merge branch 'develop' into abstract-array
HGSilveri Jul 29, 2024
c258ceb
Fix typing
HGSilveri Jul 29, 2024
4831021
Merge branch 'develop' into abstract-array
HGSilveri Aug 2, 2024
e0fd274
Addressing review comments
HGSilveri Sep 10, 2024
544063e
Merge branch 'develop' into abstract-array
HGSilveri Sep 10, 2024
5b9870d
Including `detach()` in Differentiable protocol
HGSilveri Sep 12, 2024
0b184d7
Differentiable -> TensorLike
HGSilveri Sep 12, 2024
06729c3
Tentatively allow waveform division by array
HGSilveri Sep 12, 2024
347b5ff
Full coverage
HGSilveri Sep 12, 2024
08bd2cd
Merge branch 'develop' into abstract-array
HGSilveri Sep 16, 2024
bf5d198
Merge branch 'develop' into abstract-array
HGSilveri Sep 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ per-file-ignores =
tests/*: D100, D101, D102, D103
__init__.py: F401
pulser-core/pulser/backends.py: F401
pulser-core/pulser/math/__init__.py: D103
setup.py: D100
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.8", "3.12"]
with-torch: ["with-torch", "no-torch"]
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
steps:
- name: Check out Pulser
uses: actions/checkout@v4
Expand All @@ -67,8 +68,13 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
extra-packages: pytest
with-torch: ${{ matrix.with-torch }}
- name: Run the unit tests & generate coverage report
if: ${{ matrix.with-torch == 'with-torch' }}
run: pytest --cov --cov-fail-under=100
- name: Run the unit tests without torch installed
if: ${{ matrix.with-torch != 'with-torch' }}
run: pytest --cov
a-corni marked this conversation as resolved.
Show resolved Hide resolved
- name: Test validation with legacy jsonschema
run: |
pip install jsonschema==4.17.3
Expand Down
13 changes: 12 additions & 1 deletion .github/workflows/pulser-setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ inputs:
description: Extra packages to install (give to grep)
required: false
default: ""
with-torch:
description: Whether to include pytorch
required: false
default: "with-torch"
runs:
using: "composite"
steps:
Expand All @@ -17,11 +21,18 @@ runs:
with:
python-version: ${{ inputs.python-version }}
cache: "pip"
- name: Install Pulser
- name: Install Pulser (with torch)
if: ${{ inputs.with-torch == 'with-torch' }}
shell: bash
run: |
python -m pip install --upgrade pip
make dev-install
- name: Install Pulser (without torch)
if: ${{ inputs.with-torch != 'with-torch' }}
shell: bash
run: |
python -m pip install --upgrade pip
make dev-install-no-torch
- name: Install extra packages from the dev requirements
if: "${{ inputs.extra-packages != '' }}"
shell: bash
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
# Python 3.8 and 3.9 does not run on macos-latest (14)
# Uses macos-13 for 3.8 and 3.9 and macos-latest for >=3.10
os: [ubuntu-latest, macos-13, macos-latest, windows-latest]
with-torch: ["with-torch", "no-torch"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
exclude:
- os: macos-latest
Expand All @@ -38,5 +39,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
extra-packages: pytest
with-torch: ${{ matrix.with-torch }}
- name: Run the unit tests & generate coverage report
run: pytest --cov --cov-fail-under=100
run: pytest --cov
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
.PHONY: dev-install
dev-install: dev-install-core dev-install-simulation dev-install-pasqal

.PHONY: dev-install-no-torch
dev-install-no-torch: dev-install-core-no-torch dev-install-simulation dev-install-pasqal

.PHONY: dev-install-core
dev-install-core:
pip install -e ./pulser-core[torch]

.PHONY: dev-install-core-no-torch
dev-install-core-no-torch:
pip install -e ./pulser-core

.PHONY: dev-install-simulation
Expand Down
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ If you wish to install only the core ``pulser`` features, you can instead run:
pip install pulser-core
```

### Including PyTorch

To include PyTorch in your installation, append the ``[torch]`` suffix to the commands outlined above, i.e.

```bash
pip install pulser[torch]
```

for the standard ``pulser`` distribution with PyTorch, **or**

```bash
pip install pulser-core[torch]
```

for just the core features plus PyTorch support.

### Development install

If you wish to **install the development version of Pulser from source** instead, do the following from within this repository after cloning it:

```bash
Expand Down
42 changes: 24 additions & 18 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

import numpy as np
from numpy.typing import ArrayLike
from scipy.fft import fft, fftfreq, ifft

import pulser.math as pm
from pulser.channels.eom import MODBW_TO_TR, BaseEOM
from pulser.json.utils import get_dataclass_defaults, obj_to_dict
from pulser.pulse import Pulse
Expand Down Expand Up @@ -420,22 +420,24 @@ def validate_pulse(self, pulse: Pulse) -> None:
f"'pulse' must be of type Pulse, not of type {type(pulse)}."
)

if self.max_amp is not None and np.any(
pulse.amplitude.samples > self.max_amp
):
amp_samples_np = pulse.amplitude.samples.as_array(detach=True)
if self.max_amp is not None and np.any(amp_samples_np > self.max_amp):
raise ValueError(
"The pulse's amplitude goes over the maximum "
"value allowed for the chosen channel."
)
if self.max_abs_detuning is not None and np.any(
np.round(np.abs(pulse.detuning.samples), decimals=6)
np.round(
np.abs(pulse.detuning.samples.as_array(detach=True)),
decimals=6,
)
> self.max_abs_detuning
):
raise ValueError(
"The pulse's detuning values go out of the range "
"allowed for the chosen channel."
)
avg_amp = np.average(pulse.amplitude.samples)
avg_amp = np.average(amp_samples_np)
if 0 < avg_amp < self.min_avg_amp:
raise ValueError(
"The pulse's average amplitude is below the chosen "
Expand All @@ -453,10 +455,10 @@ def _modulation_padding(self) -> int:

def modulate(
self,
input_samples: np.ndarray,
input_samples: ArrayLike,
a-corni marked this conversation as resolved.
Show resolved Hide resolved
keep_ends: bool = False,
eom: bool = False,
) -> np.ndarray:
) -> pm.AbstractArray:
"""Modulates the input according to the channel's modulation bandwidth.

Args:
Expand All @@ -482,17 +484,17 @@ def modulate(
" 'Channel.modulate()' returns the 'input_samples' unchanged.",
stacklevel=2,
)
return input_samples
return pm.AbstractArray(input_samples)
else:
mod_bandwidth = self.mod_bandwidth
mod_padding = self._modulation_padding

if keep_ends:
samples = np.pad(
samples = pm.pad(
input_samples, mod_padding + self.rise_time, mode="edge"
)
else:
samples = np.pad(input_samples, mod_padding)
samples = pm.pad(input_samples, mod_padding)
mod_samples = self.apply_modulation(samples, mod_bandwidth)
if keep_ends:
# Cut off the extra ends
Expand All @@ -501,8 +503,8 @@ def modulate(

@staticmethod
def apply_modulation(
input_samples: np.ndarray, mod_bandwidth: float
) -> np.ndarray:
input_samples: ArrayLike, mod_bandwidth: float
a-corni marked this conversation as resolved.
Show resolved Hide resolved
) -> pm.AbstractArray:
"""Applies the modulation transfer fuction to the input samples.

Note:
Expand All @@ -516,10 +518,11 @@ def apply_modulation(
"""
# The cutoff frequency (fc) and the modulation transfer function
# are defined in https://tinyurl.com/bdeumc8k
input_samples = pm.AbstractArray(input_samples)
fc = mod_bandwidth * 1e-3 / np.sqrt(np.log(2))
freqs = fftfreq(input_samples.size)
modulation = np.exp(-(freqs**2) / fc**2)
return cast(np.ndarray, ifft(fft(input_samples) * modulation).real)
freqs = pm.fftfreq(input_samples.size)
modulation = pm.exp(-(freqs**2) / fc**2)
return pm.ifft(pm.fft(input_samples) * modulation).real

def calc_modulation_buffer(
self,
a-corni marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -553,8 +556,11 @@ def calc_modulation_buffer(
f"The channel {self} doesn't have a modulation bandwidth."
)
tr = self.rise_time
samples = np.pad(input_samples, tr)
diffs = np.abs(samples - mod_samples) <= max_allowed_diff
samples = pm.pad(input_samples, tr)
diffs = (
abs(samples - mod_samples).as_array(detach=True)
a-corni marked this conversation as resolved.
Show resolved Hide resolved
<= max_allowed_diff
)
try:
# Finds the last index in the start buffer that's below the max
# allowed diff. Considers that the waveform could start at the next
Expand Down
5 changes: 4 additions & 1 deletion pulser-core/pulser/channels/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np

import pulser.math as pm
from pulser.channels.base_channel import Channel
from pulser.json.utils import get_dataclass_defaults
from pulser.pulse import Pulse
Expand Down Expand Up @@ -112,7 +113,9 @@ def validate_pulse(
(defaults to a detuning map with weight 1.0).
"""
super().validate_pulse(pulse)
round_detuning = np.round(pulse.detuning.samples, decimals=6)
round_detuning = pm.round(pulse.detuning.samples, 6).as_array(
detach=True
)
# Check that detuning is negative
if np.any(round_detuning > 0):
raise ValueError("The detuning in a DMM must not be positive.")
Expand Down
57 changes: 32 additions & 25 deletions pulser-core/pulser/channels/eom.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np

import pulser.math as pm
from pulser.json.utils import get_dataclass_defaults, obj_to_dict

# Conversion factor from modulation bandwith to rise time
Expand Down Expand Up @@ -210,30 +211,30 @@ def _switching_beams_combos(self) -> list[tuple[RydbergBeam, ...]]:
@overload
def calculate_detuning_off(
self,
amp_on: float,
detuning_on: float,
amp_on: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
optimal_detuning_off: float,
return_switching_beams: Literal[False],
) -> float:
) -> pm.AbstractArray:
pass

@overload
def calculate_detuning_off(
self,
amp_on: float,
detuning_on: float,
amp_on: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
optimal_detuning_off: float,
return_switching_beams: Literal[True],
) -> tuple[float, tuple[RydbergBeam, ...]]:
) -> tuple[pm.AbstractArray, tuple[RydbergBeam, ...]]:
pass

def calculate_detuning_off(
self,
amp_on: float,
detuning_on: float,
amp_on: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
optimal_detuning_off: float,
return_switching_beams: bool = False,
) -> float | tuple[float, tuple[RydbergBeam, ...]]:
) -> pm.AbstractArray | tuple[pm.AbstractArray, tuple[RydbergBeam, ...]]:
"""Calculates the detuning when the amplitude is off in EOM mode.

Args:
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -246,17 +247,19 @@ def calculate_detuning_off(
on and off.
"""
off_options = self.detuning_off_options(amp_on, detuning_on)
closest_option = np.abs(off_options - optimal_detuning_off).argmin()
best_det_off = cast(float, off_options[closest_option])
closest_option = np.abs(
off_options.as_array(detach=True) - optimal_detuning_off
).argmin()
best_det_off = off_options[closest_option]
if not return_switching_beams:
return best_det_off
return best_det_off, self._switching_beams_combos[closest_option]

def detuning_off_options(
self,
rabi_frequency: float,
detuning_on: float,
) -> np.ndarray:
rabi_frequency: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
) -> pm.AbstractArray:
"""Calculates the possible detuning values when the amplitude is off.

Args:
Expand All @@ -267,11 +270,14 @@ def detuning_off_options(
Returns:
The possible detuning values when in between pulses.
"""
rabi_frequency = pm.AbstractArray(rabi_frequency)
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
# detuning = offset + lightshift

# offset takes into account the lightshift when both beams are on
# which is not zero when the Rabi freq of both beams is not equal
offset = detuning_on - self._lightshift(rabi_frequency, *RydbergBeam)
offset = pm.AbstractArray(detuning_on) - self._lightshift(
rabi_frequency, *RydbergBeam
)
all_beams: set[RydbergBeam] = set(RydbergBeam)
lightshifts = []
for beams_off in self._switching_beams_combos:
Expand All @@ -280,25 +286,26 @@ def detuning_off_options(
lightshifts.append(self._lightshift(rabi_frequency, *beams_on))

# We sum the offset to all lightshifts to get the effective detuning
return np.array(lightshifts) + offset
return pm.flatten(pm.vstack(lightshifts)) + offset

def _lightshift(
self, rabi_frequency: float, *beams_on: RydbergBeam
) -> float:
self, rabi_frequency: pm.AbstractArray, *beams_on: RydbergBeam
) -> pm.AbstractArray:
# lightshift = (rabi_blue**2 - rabi_red**2) / 4 * int_detuning
rabi_freqs = self._rabi_freq_per_beam(rabi_frequency)
bias = {
RydbergBeam.RED: -self.red_shift_coeff,
RydbergBeam.BLUE: self.blue_shift_coeff,
}
# beam off -> beam_rabi_freq = 0
return sum(bias[beam] * rabi_freqs[beam] ** 2 for beam in beams_on) / (
4 * self.intermediate_detuning
return pm.AbstractArray(
sum(bias[beam] * rabi_freqs[beam] ** 2 for beam in beams_on)
/ (4 * self.intermediate_detuning)
)

def _rabi_freq_per_beam(
self, rabi_frequency: float
) -> dict[RydbergBeam, float]:
self, rabi_frequency: pm.AbstractArray
) -> dict[RydbergBeam, pm.AbstractArray]:
shift_factor = np.sqrt(
self.red_shift_coeff / self.blue_shift_coeff
if self.limiting_beam == RydbergBeam.RED
Expand All @@ -315,14 +322,14 @@ def _rabi_freq_per_beam(
if rabi_frequency <= limit_rabi_freq:
base_amp_squared = 2 * rabi_frequency * self.intermediate_detuning
return {
self.limiting_beam: np.sqrt(base_amp_squared / shift_factor),
~self.limiting_beam: np.sqrt(base_amp_squared * shift_factor),
self.limiting_beam: pm.sqrt(base_amp_squared / shift_factor),
~self.limiting_beam: pm.sqrt(base_amp_squared * shift_factor),
}

# The limiting beam is at its maximum amplitude while the other
# has the necessary amplitude to reach the desired effective rabi freq
return {
self.limiting_beam: self.max_limiting_amp,
self.limiting_beam: pm.AbstractArray(self.max_limiting_amp),
~self.limiting_beam: 2
* self.intermediate_detuning
* rabi_frequency
Expand Down
Loading