diff --git a/pulser-core/pulser/sequence/_schedule.py b/pulser-core/pulser/sequence/_schedule.py index 744040ed2..f5207d918 100644 --- a/pulser-core/pulser/sequence/_schedule.py +++ b/pulser-core/pulser/sequence/_schedule.py @@ -396,14 +396,15 @@ def disable_eom(self, channel_id: str, _skip_buffer: bool = False) -> None: else: self.wait_for_fall(channel_id) - def add_pulse( + def make_next_pulse_slot( self, pulse: Pulse, channel: str, phase_barrier_ts: list[int], protocol: str, phase_drift_params: _PhaseDriftParams | None = None, - ) -> None: + block_over_max_duration: bool = False, + ) -> _TimeSlot: def corrected_phase(tf: int) -> pm.AbstractArray: phase_drift = pm.AbstractArray( phase_drift_params.calc_phase_drift(tf) @@ -447,11 +448,10 @@ def corrected_phase(tf: int) -> pm.AbstractArray: delay_duration = max(current_max_t - t0, phase_jump_buffer) if delay_duration > 0: delay_duration = self[channel].adjust_duration(delay_duration) - self.add_delay(delay_duration, channel) ti = t0 + delay_duration tf = ti + pulse.duration - self._check_duration(tf) + self._check_duration(tf, block_over_max_duration) # dataclasses.replace() does not work on Pulse (because init=False) if phase_drift_params is not None: pulse = Pulse( @@ -460,7 +460,29 @@ def corrected_phase(tf: int) -> pm.AbstractArray: phase=corrected_phase(ti), post_phase_shift=pulse.post_phase_shift, ) - self[channel].slots.append(_TimeSlot(pulse, ti, tf, last.targets)) + return _TimeSlot(pulse, ti, tf, last.targets) + + def add_pulse( + self, + pulse: Pulse, + channel: str, + phase_barrier_ts: list[int], + protocol: str, + phase_drift_params: _PhaseDriftParams | None = None, + ) -> None: + last = self[channel][-1] + time_slot = self.make_next_pulse_slot( + pulse, + channel, + phase_barrier_ts, + protocol, + phase_drift_params, + True, + ) + delay_duration = time_slot.ti - last.tf + if delay_duration > 0: + self.add_delay(delay_duration, channel) + self[channel].slots.append(time_slot) def add_delay(self, duration: int, channel: str) -> None: last = self[channel][-1] @@ -557,9 +579,14 @@ def _get_last_pulse_phase(self, channel: str) -> pm.AbstractArray: phase = pm.AbstractArray(0.0) return phase - def _check_duration(self, t: int) -> None: + def _check_duration( + self, t: int, block_over_max_duration: bool = True + ) -> None: if self.max_duration is not None and t > self.max_duration: - raise RuntimeError( + msg = ( "The sequence's duration exceeded the maximum duration allowed" f" by the device ({self.max_duration} ns)." ) + if block_over_max_duration: + raise RuntimeError(msg) + warnings.warn(msg, UserWarning) diff --git a/pulser-core/pulser/sequence/sequence.py b/pulser-core/pulser/sequence/sequence.py index 6287dbd8b..11334dccd 100644 --- a/pulser-core/pulser/sequence/sequence.py +++ b/pulser-core/pulser/sequence/sequence.py @@ -1381,6 +1381,98 @@ def delay( """ self._delay(duration, channel, at_rest) + def estimate_added_delay( + self, + pulse: Union[Pulse, Parametrized], + channel: str, + protocol: PROTOCOLS = "min-delay", + ) -> int: + """Delay that will be added before the pulse when added to a channel. + + When adding a pulse to a channel of the Sequence, a delay can be added + to account for the modulation bandwidth of the channel or the protocol + chosen. This method estimates the delay that will be added before the + pulse if this pulse was added to this channel with this protocol. It + works even if the channel is in EOM mode, but to be appropriate, the + Pulse should be a ConstantPulse with amplitude and detuning + respectively the rabi_freq and detuning_on of the EOM block. + + Args: + pulse: The pulse object to add to the channel. + channel: The channel's name provided when declared. + protocol: Stipulates how to deal with + eventual conflicts with other channels, specifically in terms + of having multiple channels act on the same target + simultaneously. + + - ``'min-delay'``: Before adding the pulse, introduces the + smallest possible delay that avoids all exisiting conflicts. + - ``'no-delay'``: Adds the pulse to the channel, regardless of + existing conflicts. + - ``'wait-for-all'``: Before adding the pulse, adds a delay + that idles the channel until the end of the other channels' + latest pulse. + + Returns: + The delay that would be added before the pulse. + """ + self._validate_channel( + channel, + block_if_slm=channel.startswith("dmm_"), + ) + self._validate_add_protocol(protocol) + if self.is_parametrized() or isinstance(pulse, Parametrized): + raise ValueError( + "Can't compute the delay to add before a pulse if sequence or" + "pulse is parametrized." + ) + if self.is_in_eom_mode(channel): + eom_settings = self._schedule[channel].eom_blocks[-1] + if np.any(pulse.amplitude.samples != eom_settings.rabi_freq): + warnings.warn( + f"Channel {channel} is in EOM mode, the amplitude of the " + "pulse will be constant and equal to " + f"{eom_settings.rabi_freq}.", + UserWarning, + ) + if np.any(pulse.detuning.samples != eom_settings.detuning_on): + warnings.warn( + f"Channel {channel} is in EOM mode, the detuning of the " + "pulse will be constant and equal to " + f"{eom_settings.detuning_on}.", + UserWarning, + ) + channel_obj = self._schedule[channel].channel_obj + last = self._last(channel) + basis = channel_obj.basis + + ph_refs = { + self._basis_ref[basis][q].phase.last_phase for q in last.targets + } + if isinstance(channel_obj, DMM): + phase_ref = None + elif len(ph_refs) != 1: + raise ValueError( + "Cannot do a multiple-target pulse on qubits with different " + "phase references for the same basis." + ) + else: + phase_ref = ph_refs.pop() + + pulse = self._validate_and_adjust_pulse(pulse, channel, phase_ref) + + phase_barriers = [ + self._basis_ref[basis][q].phase.last_time for q in last.targets + ] + next_time_slot = self._schedule.make_next_pulse_slot( + pulse, + channel, + phase_barriers, + protocol, + # phase_drift_params does not impact delay between pulses + ) + return next_time_slot.ti - last.tf + @seq_decorators.store @seq_decorators.block_if_measured def measure(self, basis: str = "ground-rydberg") -> None: diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 6254ea50a..a871ab770 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1729,6 +1729,97 @@ def test_sequence(reg, device, patch_plt_show): assert str(seq) == str(seq_) +@pytest.mark.parametrize("eom", [False, True]) +def test_estimate_added_delay(eom): + reg = Register.square(2, 5) + seq = Sequence(reg, AnalogDevice) + pulse_0 = Pulse.ConstantPulse(100, 1, 0, 0) + pulse_pi_2 = Pulse.ConstantPulse(100, 1, 0, np.pi / 2) + + with pytest.raises( + ValueError, match="Use the name of a declared channel." + ): + seq.estimate_added_delay(pulse_0, "ising", "min-delay") + seq.declare_channel("ising", "rydberg_global") + ising_obj = seq.declared_channels["ising"] + if eom: + seq.enable_eom_mode("ising", 1, 0) + with pytest.warns( + UserWarning, + match="Channel ising is in EOM mode, the amplitude", + ): + assert ( + seq.estimate_added_delay( + Pulse.ConstantPulse(100, 2, 0, 0), "ising" + ) + == 0 + ) + with pytest.warns( + UserWarning, + match="Channel ising is in EOM mode, the detuning", + ): + assert ( + seq.estimate_added_delay( + Pulse.ConstantPulse(100, 1, 1, 0), "ising" + ) + == 0 + ) + assert seq.estimate_added_delay(pulse_0, "ising", "min-delay") == 0 + seq._add(pulse_0, "ising", "min-delay") + first_pulse = seq._last("ising") + assert first_pulse.ti == 0 + delay = pulse_0.fall_time(ising_obj, eom) + ising_obj.phase_jump_time + assert seq.estimate_added_delay(pulse_pi_2, "ising") == delay + seq._add(pulse_pi_2, "ising", "min-delay") + second_pulse = seq._last("ising") + assert second_pulse.ti - first_pulse.tf == delay + assert seq.estimate_added_delay(pulse_0, "ising") == delay + seq.delay(100, "ising") + assert seq.estimate_added_delay(pulse_0, "ising") == delay - 100 + with pytest.warns( + UserWarning, + match="The sequence's duration exceeded the maximum duration", + ): + seq.estimate_added_delay( + pulser.Pulse.ConstantPulse(4000, 1, 0, np.pi), "ising" + ) + var = seq.declare_variable("var", dtype=int) + with pytest.raises( + ValueError, match="Can't compute the delay to add before a pulse" + ): + seq.estimate_added_delay(Pulse.ConstantPulse(var, 1, 0, 0), "ising") + # We shift the phase of just one qubit, which blocks addition + # of new pulses on this basis + seq.phase_shift(1.0, 0, basis="ground-rydberg") + with pytest.raises( + ValueError, + match="Cannot do a multiple-target pulse on qubits with different", + ): + seq.estimate_added_delay(pulse_0, "ising") + + +def test_estimate_added_delay_dmm(): + pulse_0 = Pulse.ConstantPulse(100, 1, 0, 0) + det_pulse = Pulse.ConstantPulse(100, 0, -1, 0) + seq = Sequence(Register.square(2, 5), DigitalAnalogDevice) + seq.declare_channel("ising", "rydberg_global") + seq.config_slm_mask([0, 1]) + with pytest.raises( + ValueError, match="You should add a Pulse to a Global Channel" + ): + seq.estimate_added_delay(det_pulse, "dmm_0") + seq.add(pulse_0, "ising") + assert seq.estimate_added_delay(det_pulse, "dmm_0") == 0 + with pytest.raises( + ValueError, match="The detuning in a DMM must not be positive." + ): + seq.estimate_added_delay(Pulse.ConstantPulse(100, 0, 1, 0), "dmm_0") + with pytest.raises( + ValueError, match="The pulse's amplitude goes over the maximum" + ): + seq.estimate_added_delay(pulse_0, "dmm_0") + + @pytest.mark.parametrize("qubit_ids", [["q0", "q1", "q2"], [0, 1, 2]]) def test_config_slm_mask(qubit_ids, device, det_map): reg: Register | MappableRegister