diff --git a/pulser-core/pulser/json/abstract_repr/deserializer.py b/pulser-core/pulser/json/abstract_repr/deserializer.py index ff6d1b78..13093081 100644 --- a/pulser-core/pulser/json/abstract_repr/deserializer.py +++ b/pulser-core/pulser/json/abstract_repr/deserializer.py @@ -291,6 +291,16 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None: ), correct_phase_drift=op.get("correct_phase_drift", False), ) + elif op["op"] == "modify_eom_setpoint": + seq.modify_eom_setpoint( + channel=op["channel"], + amp_on=_deserialize_parameter(op["amp_on"], vars), + detuning_on=_deserialize_parameter(op["detuning_on"], vars), + optimal_detuning_off=_deserialize_parameter( + op["optimal_detuning_off"], vars + ), + correct_phase_drift=op["correct_phase_drift"], + ) elif op["op"] == "add_eom_pulse": seq.add_eom_pulse( channel=op["channel"], diff --git a/pulser-core/pulser/json/abstract_repr/schemas/sequence-schema.json b/pulser-core/pulser/json/abstract_repr/schemas/sequence-schema.json index 48838461..2c0b8afd 100644 --- a/pulser-core/pulser/json/abstract_repr/schemas/sequence-schema.json +++ b/pulser-core/pulser/json/abstract_repr/schemas/sequence-schema.json @@ -687,6 +687,44 @@ ], "type": "object" }, + "OpModifyEOM": { + "additionalProperties": false, + "properties": { + "amp_on": { + "$ref": "#/definitions/ParametrizedNum", + "description": "The new amplitude of the EOM pulses (in rad/µs)." + }, + "channel": { + "$ref": "#/definitions/ChannelName", + "description": "The name of the channel currently in EOM mode." + }, + "correct_phase_drift": { + "description": "Performs a phase shift to correct for the phase drift incurred while modifying the EOM setpoint.", + "type": "boolean" + }, + "detuning_on": { + "$ref": "#/definitions/ParametrizedNum", + "description": "The new detuning of the EOM pulses (in rad/µs)." + }, + "op": { + "const": "modify_eom_setpoint", + "type": "string" + }, + "optimal_detuning_off": { + "$ref": "#/definitions/ParametrizedNum", + "description": "The new optimal value of detuning (in rad/µs) when there is no pulse being played. It will choose the closest value among the existing options." + } + }, + "required": [ + "op", + "channel", + "amp_on", + "detuning_on", + "optimal_detuning_off", + "correct_phase_drift" + ], + "type": "object" + }, "OpPhaseShift": { "additionalProperties": false, "description": "Adds a separate phase shift to atoms. If possible, OpPulse phase and post_phase_shift are preferred.", @@ -865,6 +903,9 @@ { "$ref": "#/definitions/OpEnableEOM" }, + { + "$ref": "#/definitions/OpModifyEOM" + }, { "$ref": "#/definitions/OpDisableEOM" }, diff --git a/pulser-core/pulser/json/abstract_repr/serializer.py b/pulser-core/pulser/json/abstract_repr/serializer.py index 925bc618..6b5ab3bc 100644 --- a/pulser-core/pulser/json/abstract_repr/serializer.py +++ b/pulser-core/pulser/json/abstract_repr/serializer.py @@ -358,6 +358,18 @@ def remove_kwarg_if_default( data, call.name, "correct_phase_drift" ) operations.append({"op": "enable_eom_mode", **data}) + elif call.name == "modify_eom_setpoint": + data = get_all_args( + ( + "channel", + "amp_on", + "detuning_on", + "optimal_detuning_off", + "correct_phase_drift", + ), + call, + ) + operations.append({"op": "modify_eom_setpoint", **data}) elif call.name == "add_eom_pulse": data = get_all_args( ( diff --git a/pulser-core/pulser/sequence/_schedule.py b/pulser-core/pulser/sequence/_schedule.py index 8f2847ba..3384c63f 100644 --- a/pulser-core/pulser/sequence/_schedule.py +++ b/pulser-core/pulser/sequence/_schedule.py @@ -341,12 +341,14 @@ def enable_eom( detuning_off: float, switching_beams: tuple[RydbergBeam, ...] = (), _skip_buffer: bool = False, + _skip_wait_for_fall: bool = False, ) -> None: channel_obj = self[channel_id].channel_obj # Adds a buffer unless the channel is empty or _skip_buffer = True if not _skip_buffer and self.get_duration(channel_id): - # Wait for the last pulse to ramp down (if needed) - self.wait_for_fall(channel_id) + if not _skip_wait_for_fall: + # Wait for the last pulse to ramp down (if needed) + self.wait_for_fall(channel_id) eom_buffer_time = self[channel_id].adjust_duration( channel_obj._eom_buffer_time ) diff --git a/pulser-core/pulser/sequence/sequence.py b/pulser-core/pulser/sequence/sequence.py index 0f9e6efa..3869c123 100644 --- a/pulser-core/pulser/sequence/sequence.py +++ b/pulser-core/pulser/sequence/sequence.py @@ -43,7 +43,7 @@ import pulser.sequence._decorators as seq_decorators from pulser.channels.base_channel import Channel, States, get_states_from_bases from pulser.channels.dmm import DMM, _dmm_id_from_name, _get_dmm_name -from pulser.channels.eom import RydbergEOM +from pulser.channels.eom import RydbergBeam, RydbergEOM from pulser.devices._device_datacls import BaseDevice from pulser.json.abstract_repr.deserializer import ( deserialize_abstract_sequence, @@ -1139,54 +1139,35 @@ def enable_eom_mode( raise RuntimeError( f"The '{channel}' channel is already in EOM mode." ) + channel_obj = self.declared_channels[channel] if not channel_obj.supports_eom(): raise TypeError(f"Channel '{channel}' does not have an EOM.") - on_pulse = Pulse.ConstantPulse( - channel_obj.min_duration, amp_on, detuning_on, 0.0 + detuning_off, switching_beams = self._process_eom_parameters( + channel_obj, amp_on, detuning_on, optimal_detuning_off ) - stored_opt_detuning_off = optimal_detuning_off - if not isinstance(on_pulse, Parametrized): - channel_obj.validate_pulse(on_pulse) - amp_on = cast(float, amp_on) - detuning_on = cast(float, detuning_on) - eom_config = cast(RydbergEOM, channel_obj.eom_config) - if not isinstance(optimal_detuning_off, Parametrized): - ( - detuning_off, - switching_beams, - ) = eom_config.calculate_detuning_off( - amp_on, - detuning_on, - optimal_detuning_off, - return_switching_beams=True, - ) - off_pulse = Pulse.ConstantPulse( - channel_obj.min_duration, 0.0, detuning_off, 0.0 - ) - channel_obj.validate_pulse(off_pulse) - # Update optimal_detuning_off to match the chosen detuning_off - # This minimizes the changes to the sequence when the device - # is switched - stored_opt_detuning_off = detuning_off - - if not self.is_parametrized(): - phase_drift_params = _PhaseDriftParams( - drift_rate=-detuning_off, - # enable_eom() calls wait for fall, so the block only - # starts after fall time - ti=self.get_duration(channel, include_fall_time=True), - ) - self._schedule.enable_eom( - channel, amp_on, detuning_on, detuning_off, switching_beams + if not self.is_parametrized(): + detuning_off = cast(float, detuning_off) + phase_drift_params = _PhaseDriftParams( + drift_rate=-detuning_off, + # enable_eom() calls wait for fall, so the block only + # starts after fall time + ti=self.get_duration(channel, include_fall_time=True), + ) + self._schedule.enable_eom( + channel, + cast(float, amp_on), + cast(float, detuning_on), + detuning_off, + switching_beams, + ) + if correct_phase_drift: + buffer_slot = self._last(channel) + drift = phase_drift_params.calc_phase_drift(buffer_slot.tf) + self._phase_shift( + -drift, *buffer_slot.targets, basis=channel_obj.basis ) - if correct_phase_drift: - buffer_slot = self._last(channel) - drift = phase_drift_params.calc_phase_drift(buffer_slot.tf) - self._phase_shift( - -drift, *buffer_slot.targets, basis=channel_obj.basis - ) # Manually store the call to "enable_eom_mode" so that the updated # 'optimal_detuning_off' is stored @@ -1201,7 +1182,7 @@ def enable_eom_mode( channel=channel, amp_on=amp_on, detuning_on=detuning_on, - optimal_detuning_off=stored_opt_detuning_off, + optimal_detuning_off=detuning_off, correct_phase_drift=correct_phase_drift, ), ) @@ -1253,6 +1234,90 @@ def disable_eom_mode( basis=ch_schedule.channel_obj.basis, ) + @seq_decorators.verify_parametrization + @seq_decorators.block_if_measured + def modify_eom_setpoint( + self, + channel: str, + amp_on: Union[float, Parametrized], + detuning_on: Union[float, Parametrized], + optimal_detuning_off: Union[float, Parametrized] = 0.0, + correct_phase_drift: bool = False, + ) -> None: + """Modifies the setpoint of an ongoing EOM mode operation. + + Note: + Modifying the EOM setpoint will automatically enforce a buffer. + The detuning will go to the `detuning_off` value during + this buffer. This buffer will not wait for pulses on other + channels to finish, so calling `Sequence.align()` or + `Sequence.delay()` beforehand is necessary to avoid eventual + conflicts. + + Args: + channel: The name of the channel currently in EOM mode. + amp_on: The new amplitude of the EOM pulses (in rad/µs). + detuning_on: The new detuning of the EOM pulses (in rad/µs). + optimal_detuning_off: The new optimal value of detuning (in rad/µs) + when there is no pulse being played. It will choose the closest + value among the existing options. + correct_phase_drift: Performs a phase shift to correct for the + phase drift incurred while modifying the EOM setpoint. + """ + if not self.is_in_eom_mode(channel): + raise RuntimeError(f"The '{channel}' channel is not in EOM mode.") + + channel_obj = self.declared_channels[channel] + detuning_off, switching_beams = self._process_eom_parameters( + channel_obj, amp_on, detuning_on, optimal_detuning_off + ) + + if not self.is_parametrized(): + detuning_off = cast(float, detuning_off) + self._schedule.disable_eom(channel, _skip_buffer=True) + old_phase_drift_params = self._get_last_eom_pulse_phase_drift( + channel + ) + new_phase_drift_params = _PhaseDriftParams( + drift_rate=-detuning_off, + ti=self.get_duration(channel, include_fall_time=False), + ) + self._schedule.enable_eom( + channel, + cast(float, amp_on), + cast(float, detuning_on), + detuning_off, + switching_beams, + _skip_wait_for_fall=True, + ) + if correct_phase_drift: + buffer_slot = self._last(channel) + drift = old_phase_drift_params.calc_phase_drift( + buffer_slot.ti + ) + new_phase_drift_params.calc_phase_drift(buffer_slot.tf) + self._phase_shift( + -drift, *buffer_slot.targets, basis=channel_obj.basis + ) + + # Manually store the call to "modify_eom_setpoint" so that the updated + # 'optimal_detuning_off' is stored + call_container = ( + self._to_build_calls if self.is_parametrized() else self._calls + ) + call_container.append( + _Call( + "modify_eom_setpoint", + (), + dict( + channel=channel, + amp_on=amp_on, + detuning_on=detuning_on, + optimal_detuning_off=detuning_off, + correct_phase_drift=correct_phase_drift, + ), + ) + ) + @seq_decorators.store @seq_decorators.mark_non_empty @seq_decorators.block_if_measured @@ -2389,6 +2454,43 @@ def _validate_add_protocol(self, protocol: str) -> None: + ", ".join(valid_protocols) ) + def _process_eom_parameters( + self, + channel_obj: Channel, + amp_on: Union[float, Parametrized], + detuning_on: Union[float, Parametrized], + optimal_detuning_off: Union[float, Parametrized], + ) -> tuple[float | Parametrized, tuple[RydbergBeam, ...]]: + on_pulse = Pulse.ConstantPulse( + channel_obj.min_duration, amp_on, detuning_on, 0.0 + ) + stored_opt_detuning_off = optimal_detuning_off + switching_beams: tuple[RydbergBeam, ...] = () + if not isinstance(on_pulse, Parametrized): + channel_obj.validate_pulse(on_pulse) + amp_on = cast(float, amp_on) + detuning_on = cast(float, detuning_on) + eom_config = cast(RydbergEOM, channel_obj.eom_config) + if not isinstance(optimal_detuning_off, Parametrized): + ( + detuning_off, + switching_beams, + ) = eom_config.calculate_detuning_off( + amp_on, + detuning_on, + optimal_detuning_off, + return_switching_beams=True, + ) + off_pulse = Pulse.ConstantPulse( + channel_obj.min_duration, 0.0, detuning_off, 0.0 + ) + channel_obj.validate_pulse(off_pulse) + # Update optimal_detuning_off to match the chosen detuning_off + # This minimizes the changes to the sequence when the device + # is switched + stored_opt_detuning_off = detuning_off + return stored_opt_detuning_off, switching_beams + def _reset_parametrized(self) -> None: """Resets all attributes related to parametrization.""" # Signals the sequence as actively "building" ie not parametrized diff --git a/tests/test_abstract_repr.py b/tests/test_abstract_repr.py index 3cefbb09..f61e9679 100644 --- a/tests/test_abstract_repr.py +++ b/tests/test_abstract_repr.py @@ -935,6 +935,13 @@ def test_eom_mode( "ryd", duration, 0.0, correct_phase_drift=correct_phase_drift ) seq.delay(duration, "ryd", at_rest=delay_at_rest) + seq.modify_eom_setpoint( + "ryd", + amp_on=2.0, + detuning_on=-1.0, + optimal_detuning_off=det_off, + correct_phase_drift=correct_phase_drift, + ) seq.disable_eom_mode("ryd", correct_phase_drift) abstract = json.loads(seq.to_abstract_repr()) @@ -992,6 +999,21 @@ def test_eom_mode( } assert abstract["operations"][3] == { + **{ + "op": "modify_eom_setpoint", + "channel": "ryd", + "amp_on": 2.0, + "detuning_on": -1.0, + "optimal_detuning_off": { + "expression": "index", + "lhs": {"variable": "det_off"}, + "rhs": 0, + }, + "correct_phase_drift": correct_phase_drift, + }, + } + + assert abstract["operations"][4] == { **{ "op": "disable_eom_mode", "channel": "ryd", @@ -1230,7 +1252,11 @@ def _check_roundtrip(serialized_seq: dict[str, Any]): *(op[wf][qty] for qty in wf_args) ) op[wf] = reconstructed_wf._to_abstract_repr() - elif "eom" in op["op"] and not op.get("correct_phase_drift"): + elif ( + "eom" in op["op"] + and not op.get("correct_phase_drift") + and op["op"] != "modify_eom_setpoint" + ): # Remove correct_phase_drift when at default, since the # roundtrip will delete it op.pop("correct_phase_drift", None) @@ -2055,6 +2081,14 @@ def test_deserialize_eom_ops(self, correct_phase_drift, var_detuning_on): "protocol": "no-delay", "correct_phase_drift": correct_phase_drift, }, + { + "op": "modify_eom_setpoint", + "channel": "global", + "amp_on": 1.0, + "detuning_on": detuning_on, + "optimal_detuning_off": -0.5, + "correct_phase_drift": correct_phase_drift or False, + }, { "op": "disable_eom_mode", "channel": "global", @@ -2070,13 +2104,14 @@ def test_deserialize_eom_ops(self, correct_phase_drift, var_detuning_on): ) if correct_phase_drift is None: for op in s["operations"]: - del op["correct_phase_drift"] + if "modify" not in op["op"]: + del op["correct_phase_drift"] seq = Sequence.from_abstract_repr(json.dumps(s)) # init + declare_channel + enable_eom_mode (if not var_detuning_on) assert len(seq._calls) == 3 - var_detuning_on # add_eom_pulse + disable_eom + enable_eom_mode (if var_detuning_on) - assert len(seq._to_build_calls) == 2 + var_detuning_on + assert len(seq._to_build_calls) == 3 + var_detuning_on if var_detuning_on: enable_eom_call = seq._to_build_calls[0] @@ -2108,6 +2143,21 @@ def test_deserialize_eom_ops(self, correct_phase_drift, var_detuning_on): else: assert detuning_on_kwarg == detuning_on + modify_eom_call = seq._to_build_calls[-2] + assert modify_eom_call.name == "modify_eom_setpoint" + modify_eom_kwargs = modify_eom_call.kwargs.copy() + detuning_on_kwarg = modify_eom_kwargs.pop("detuning_on") + assert modify_eom_kwargs == { + "channel": "global", + "amp_on": 1.0, + "optimal_detuning_off": -0.5, + "correct_phase_drift": bool(correct_phase_drift), + } + if var_detuning_on: + assert isinstance(detuning_on_kwarg, VariableItem) + else: + assert detuning_on_kwarg == detuning_on + disable_eom_call = seq._to_build_calls[-1] assert disable_eom_call.name == "disable_eom_mode" assert disable_eom_call.kwargs == { diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 12a7f685..876cd56e 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2408,6 +2408,69 @@ def test_eom_buffer( ) +@pytest.mark.parametrize("correct_phase_drift", [True, False]) +@pytest.mark.parametrize("amp_diff", [0, -0.5, 0.5]) +@pytest.mark.parametrize("det_diff", [0, -5, 10]) +def test_modify_eom_setpoint( + reg, mod_device, amp_diff, det_diff, correct_phase_drift +): + seq = Sequence(reg, mod_device) + seq.declare_channel("ryd", "rydberg_global") + params = seq.declare_variable("params", dtype=float, size=2) + dt = 100 + amp, det_on = params + with pytest.raises( + RuntimeError, match="The 'ryd' channel is not in EOM mode" + ): + seq.modify_eom_setpoint("ryd", amp, det_on) + seq.enable_eom_mode("ryd", amp, det_on) + assert seq.is_in_eom_mode("ryd") + seq.add_eom_pulse("ryd", dt, 0.0) + seq.delay(dt, "ryd") + + new_amp, new_det_on = amp + amp_diff, det_on + det_diff + seq.modify_eom_setpoint( + "ryd", new_amp, new_det_on, correct_phase_drift=correct_phase_drift + ) + assert seq.is_in_eom_mode("ryd") + seq.add_eom_pulse("ryd", dt, 0.0) + seq.delay(dt, "ryd") + + ryd_ch_obj = seq.declared_channels["ryd"] + eom_buffer_dt = ryd_ch_obj._eom_buffer_time + param_vals = [1.0, 0.0] + built_seq = seq.build(params=param_vals) + expected_duration = 4 * dt + eom_buffer_dt + assert built_seq.get_duration() == expected_duration + + amp, det = param_vals + ch_samples = sample(built_seq).channel_samples["ryd"] + expected_amp = np.zeros(expected_duration) + expected_amp[:dt] = amp + expected_amp[-2 * dt : -dt] = amp + amp_diff + np.testing.assert_array_equal(expected_amp, ch_samples.amp) + + det_off = ryd_ch_obj.eom_config.calculate_detuning_off(amp, det, 0.0) + new_det_off = ryd_ch_obj.eom_config.calculate_detuning_off( + amp + amp_diff, det + det_diff, 0.0 + ) + expected_det = np.zeros(expected_duration) + expected_det[:dt] = det + expected_det[dt : 2 * dt] = det_off + expected_det[2 * dt : 2 * dt + eom_buffer_dt] = new_det_off + expected_det[-2 * dt : -dt] = det + det_diff + expected_det[-dt:] = new_det_off + np.testing.assert_array_equal(expected_det, ch_samples.det) + + final_phase = built_seq.current_phase_ref("q0", "ground-rydberg") + if not correct_phase_drift: + assert final_phase == 0.0 + else: + assert final_phase != 0.0 + np.testing.assert_array_equal(ch_samples.phase[: 2 * dt], 0.0) + np.testing.assert_array_equal(ch_samples.phase[-2 * dt :], final_phase) + + def test_max_duration(reg, mod_device): dev_ = dataclasses.replace(mod_device, max_sequence_duration=100) seq = Sequence(reg, dev_)