Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract representation support for new DMM features #568

Merged
merged 8 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,10 @@ def type_check(
for dmm_obj in self.dmm_objects:
type_check("All DMM channels", DMM, value_override=dmm_obj)

# TODO: Check that device has dmm objects if it supports SLM mask
# once DMM is supported for serialization
# if self.supports_slm_mask and not self.dmm_objects:
# raise ValueError(
# "One DMM object should be defined to support SLM mask."
# )
if self.supports_slm_mask and not self.dmm_objects:
raise ValueError(
"One DMM object should be defined to support SLM mask."
)

if self.channel_ids is not None:
if not (
Expand Down Expand Up @@ -455,6 +453,9 @@ def _to_abstract_repr(self) -> dict[str, Any]:
for p in ALWAYS_OPTIONAL_PARAMS:
if params[p] == defaults[p]:
params.pop(p, None)
# Delete parameters of PARAMS_WITH_ABSTR_REPR in params
for p in PARAMS_WITH_ABSTR_REPR:
params.pop(p, None)
ch_list = []
for ch_name, ch_obj in self.channels.items():
ch_list.append(ch_obj._to_abstract_repr(ch_name))
Expand All @@ -463,12 +464,8 @@ def _to_abstract_repr(self) -> dict[str, Any]:
dmm_list = []
for dmm_name, dmm_obj in self.dmm_channels.items():
dmm_list.append(dmm_obj._to_abstract_repr(dmm_name))
# Add dmm channels if different than default
if "dmm_objects" in params:
params["dmm_channels"] = dmm_list
# Delete parameters of PARAMS_WITH_ABSTR_REPR in params
for p in PARAMS_WITH_ABSTR_REPR:
params.pop(p, None)
if dmm_list:
params["dmm_objects"] = dmm_list
return params

def to_abstract_repr(self) -> str:
Expand Down Expand Up @@ -668,6 +665,8 @@ class VirtualDevice(BaseDevice):
max_atom_num: int | None = None
max_radial_distance: int | None = None
supports_slm_mask: bool = True
# Needed to support SLM mask by default
dmm_objects: tuple[DMM, ...] = (DMM(),)
reusable_channels: bool = True

@property
Expand Down
19 changes: 9 additions & 10 deletions pulser-core/pulser/devices/_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Definitions of real devices."""
import numpy as np

from pulser.channels import Raman, Rydberg
from pulser.channels import DMM, Raman, Rydberg
from pulser.channels.eom import RydbergBeam, RydbergEOM
from pulser.devices._device_datacls import Device
from pulser.register.special_layouts import TriangularLatticeLayout
Expand Down Expand Up @@ -56,15 +56,14 @@
max_duration=2**26,
),
),
# TODO: Add DMM once it is supported for serialization
# dmm_objects=(
# DMM(
# clock_period=4,
# min_duration=16,
# max_duration=2**26,
# bottom_detuning=-20,
# ),
# ),
dmm_objects=(
DMM(
clock_period=4,
min_duration=16,
max_duration=2**26,
bottom_detuning=-20,
),
),
)

IroiseMVP = Device(
Expand Down
4 changes: 2 additions & 2 deletions pulser-core/pulser/devices/_mock_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pulser.channels import Microwave, Raman, Rydberg
from pulser.channels import DMM, Microwave, Raman, Rydberg
from pulser.devices._device_datacls import VirtualDevice

MockDevice = VirtualDevice(
Expand All @@ -31,5 +31,5 @@
Raman.Local(None, None, max_duration=None),
Microwave.Global(None, None, max_duration=None),
),
# TODO: Add DMM once it is supported for serialization
dmm_objects=(DMM(),),
)
38 changes: 33 additions & 5 deletions pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import pulser
import pulser.devices as devices
from pulser.channels import Microwave, Raman, Rydberg
from pulser.channels import DMM, Microwave, Raman, Rydberg
from pulser.channels.base_channel import Channel
from pulser.channels.eom import (
OPTIONAL_ABSTR_EOM_FIELDS,
Expand All @@ -44,6 +44,7 @@
from pulser.register.mappable_reg import MappableRegister
from pulser.register.register import Register
from pulser.register.register_layout import RegisterLayout
from pulser.register.weight_maps import DetuningMap
from pulser.waveforms import (
BlackmanWaveform,
CompositeWaveform,
Expand Down Expand Up @@ -276,14 +277,25 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
)
elif op["op"] == "disable_eom_mode":
seq.disable_eom_mode(channel=op["channel"])
elif op["op"] == "config_slm_mask":
seq.config_slm_mask(qubits=op["qubits"], dmm_id=op["dmm_id"])
elif op["op"] == "modulate_det_map":
seq.modulate_det_map(
waveform=_deserialize_waveform(op["waveform"], vars),
dmm_name=op["dmm_name"],
protocol=op["protocol"],
)


def _deserialize_channel(obj: dict[str, Any]) -> Channel:
params: dict[str, Any] = {}
channel_cls: Type[Channel]
if obj["basis"] == "ground-rydberg":
channel_cls = Rydberg
params["eom_config"] = None
if "bottom_detuning" in obj:
channel_cls = DMM
else:
channel_cls = Rydberg
params["eom_config"] = None
if obj["eom_config"] is not None:
data = obj["eom_config"]
try:
Expand Down Expand Up @@ -347,9 +359,9 @@ def _deserialize_device_object(obj: dict[str, Any]) -> Device | VirtualDevice:
params: dict[str, Any] = dict(
channel_ids=tuple(ch_ids), channel_objects=tuple(ch_objs)
)
if "dmm_channels" in obj:
if "dmm_objects" in obj:
params["dmm_objects"] = tuple(
_deserialize_channel(dmm_ch) for dmm_ch in obj["dmm_channels"]
_deserialize_channel(dmm_ch) for dmm_ch in obj["dmm_objects"]
)
device_fields = dataclasses.fields(device_cls)
device_defaults = get_dataclass_defaults(device_fields)
Expand Down Expand Up @@ -428,8 +440,24 @@ def deserialize_abstract_sequence(obj_str: str) -> Sequence:

# SLM Mask
if "slm_mask_targets" in obj:
# This is kept for backwards compatibility
seq.config_slm_mask(obj["slm_mask_targets"])

# Detuning Map configuration
if "dmm_channels" in obj:
for dmm_id, ser_det_map in obj["dmm_channels"]:
trap_coords = []
weights = []
for trap in ser_det_map["traps"]:
trap_coords.append((trap["x"], trap["y"]))
weights.append(trap["weight"])
det_map = DetuningMap(
trap_coordinates=trap_coords,
weights=weights,
slug=ser_det_map.get("slug"),
)
seq.config_detuning_map(detuning_map=det_map, dmm_id=dmm_id)

# Variables
vars = {}
for name, desc in obj["variables"].items():
Expand Down
Loading