diff --git a/docs/explanations/decisions/0003-make-devices-factory.md b/docs/explanations/decisions/0003-make-devices-factory.md new file mode 100644 index 0000000000..3292453bb1 --- /dev/null +++ b/docs/explanations/decisions/0003-make-devices-factory.md @@ -0,0 +1,28 @@ +# 3. Add device factory decorator with lazy connect support + +Date: 2024-04-26 + +## Status + +Accepted + +## Context + +Device instances should be capable of being created without necessarily connecting, so long as they are connected prior to being utilised to collect data. The current method puts requirements on the init method of device classes, and does not expose all options for connecting to ophyd-async devices. + +## Decision + +DAQ members led us to this proposal: + +- ophyd-async: make Device.connect(mock, timeout, force=False) idempotent +- ophyd-async: make ensure_connected(\*devices) plan stub +- dodal: make device_factory() decorator that may construct, name, cache and connect a device +- dodal: collect_factories() returns all device factories +- blueapi: call collect_factories(), instantiate and connect Devices appropriately, log those that fail +- blueapi: when plan is called, run ensure_connected on all plan args and defaults that are Devices + +We can then iterate on this if the parallel connect causes a broadcast storm. We could also in future add a monitor to a heartbeat PV per device in Device.connect so that it would reconnect next time it was called. + +## Consequences + +Beamlines will be converted to use the decorator, and default arguments to plans should be replaced with a non-eagerly connecting call to the initializer controlling device. diff --git a/src/dodal/beamlines/i22.py b/src/dodal/beamlines/i22.py index dde357874a..fb7ef57135 100644 --- a/src/dodal/beamlines/i22.py +++ b/src/dodal/beamlines/i22.py @@ -5,12 +5,12 @@ from ophyd_async.fastcs.panda import HDFPanda from dodal.common.beamlines.beamline_utils import ( - device_instantiation, + device_factory, get_path_provider, set_path_provider, ) from dodal.common.beamlines.beamline_utils import set_beamline as set_utils_beamline -from dodal.common.beamlines.device_helpers import numbered_slits +from dodal.common.beamlines.device_helpers import HDF5_PREFIX from dodal.common.crystal_metadata import ( MaterialsEnum, make_crystal_metadata_from_material, @@ -27,9 +27,10 @@ from dodal.devices.undulator import Undulator from dodal.devices.watsonmarlow323_pump import WatsonMarlow323Pump from dodal.log import set_beamline as set_log_beamline -from dodal.utils import BeamlinePrefix, get_beamline_name, skip_device +from dodal.utils import BeamlinePrefix, get_beamline_name BL = get_beamline_name("i22") +PREFIX = BeamlinePrefix(BL) set_log_beamline(BL) set_utils_beamline(BL) @@ -47,237 +48,141 @@ ) -def saxs( - wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False -) -> PilatusDetector: - return device_instantiation( - NXSasPilatus, - "saxs", - "-EA-PILAT-01:", - wait_for_connection, - fake_with_ophyd_sim, - drv_suffix="CAM:", - hdf_suffix="HDF5:", - metadata_holder=NXSasMetadataHolder( - x_pixel_size=(1.72e-1, "mm"), - y_pixel_size=(1.72e-1, "mm"), - description="Dectris Pilatus3 2M", - type="Photon Counting Hybrid Pixel", - sensor_material="silicon", - sensor_thickness=(0.45, "mm"), - distance=(4711.833684146172, "mm"), - ), +@device_factory() +def saxs() -> PilatusDetector: + metadata_holder = NXSasMetadataHolder( + x_pixel_size=(1.72e-1, "mm"), + y_pixel_size=(1.72e-1, "mm"), + description="Dectris Pilatus3 2M", + type="Photon Counting Hybrid Pixel", + sensor_material="silicon", + sensor_thickness=(0.45, "mm"), + distance=(4711.833684146172, "mm"), + ) + return NXSasPilatus( + prefix=f"{PREFIX.beamline_prefix}-EA-PILAT-01:", path_provider=get_path_provider(), + drv_suffix="CAM:", + hdf_suffix=HDF5_PREFIX, + metadata_holder=metadata_holder, ) -def synchrotron( - wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False -) -> Synchrotron: - return device_instantiation( - Synchrotron, - "synchrotron", - "", - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory() +def synchrotron() -> Synchrotron: + return Synchrotron() -def waxs( - wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False -) -> PilatusDetector: - return device_instantiation( - NXSasPilatus, - "waxs", - "-EA-PILAT-03:", - wait_for_connection, - fake_with_ophyd_sim, - drv_suffix="CAM:", - hdf_suffix="HDF5:", - metadata_holder=NXSasMetadataHolder( - x_pixel_size=(1.72e-1, "mm"), - y_pixel_size=(1.72e-1, "mm"), - description="Dectris Pilatus3 2M", - type="Photon Counting Hybrid Pixel", - sensor_material="silicon", - sensor_thickness=(0.45, "mm"), - distance=(175.4199417092314, "mm"), - ), +@device_factory() +def waxs() -> PilatusDetector: + metadata_holder = NXSasMetadataHolder( + x_pixel_size=(1.72e-1, "mm"), + y_pixel_size=(1.72e-1, "mm"), + description="Dectris Pilatus3 2M", + type="Photon Counting Hybrid Pixel", + sensor_material="silicon", + sensor_thickness=(0.45, "mm"), + distance=(175.4199417092314, "mm"), + ) + return NXSasPilatus( + prefix=f"{PREFIX.beamline_prefix}-EA-PILAT-03:", path_provider=get_path_provider(), + drv_suffix="CAM:", + hdf_suffix=HDF5_PREFIX, + metadata_holder=metadata_holder, ) -def i0( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> TetrammDetector: - return device_instantiation( - TetrammDetector, - "i0", - "-EA-XBPM-02:", - wait_for_connection, - fake_with_ophyd_sim, - type="Cividec Diamond XBPM", +@device_factory() +def i0() -> TetrammDetector: + return TetrammDetector( + prefix=f"{PREFIX.beamline_prefix}-EA-XBPM-02:", path_provider=get_path_provider(), + type="Cividec Diamond XBPM", ) -def it( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> TetrammDetector: - return device_instantiation( - TetrammDetector, - "it", - "-EA-TTRM-02:", - wait_for_connection, - fake_with_ophyd_sim, - type="PIN Diode", +@device_factory() +def it() -> TetrammDetector: + return TetrammDetector( + prefix=f"{PREFIX.beamline_prefix}-EA-TTRM-02:", path_provider=get_path_provider(), + type="PIN Diode", ) -def vfm( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> FocusingMirror: - return device_instantiation( - FocusingMirror, - "vfm", - "-OP-KBM-01:VFM:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory() +def vfm() -> FocusingMirror: + return FocusingMirror( + prefix=f"{PREFIX.beamline_prefix}-OP-KBM-01:VFM:", ) -def hfm( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> FocusingMirror: - return device_instantiation( - FocusingMirror, - "hfm", - "-OP-KBM-01:HFM:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory() +def hfm() -> FocusingMirror: + return FocusingMirror( + prefix=f"{PREFIX.beamline_prefix}-OP-KBM-01:HFM:", ) -def dcm( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> DoubleCrystalMonochromator: - return device_instantiation( - DoubleCrystalMonochromator, - "dcm", - f"{BeamlinePrefix(BL).beamline_prefix}-MO-DCM-01:", - wait_for_connection, - fake_with_ophyd_sim, - bl_prefix=False, - temperature_prefix=f"{BeamlinePrefix(BL).beamline_prefix}-DI-DCM-01:", +@device_factory() +def dcm() -> DoubleCrystalMonochromator: + return DoubleCrystalMonochromator( + prefix=f"{PREFIX.beamline_prefix}-MO-DCM-01:", + temperature_prefix=f"{PREFIX.beamline_prefix}-DI-DCM-01:", crystal_1_metadata=make_crystal_metadata_from_material( MaterialsEnum.Si, (1, 1, 1) ), crystal_2_metadata=make_crystal_metadata_from_material( - MaterialsEnum.Si, - (1, 1, 1), + MaterialsEnum.Si, (1, 1, 1) ), ) -def undulator( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> Undulator: - return device_instantiation( - Undulator, - "undulator", - f"{BeamlinePrefix(BL).insertion_prefix}-MO-SERVC-01:", - wait_for_connection, - fake_with_ophyd_sim, - bl_prefix=False, +@device_factory() +def undulator() -> Undulator: + return Undulator( + prefix=f"{PREFIX.insertion_prefix}-MO-SERVC-01:", + id_gap_lookup_table_path="/dls_sw/i22/software/daq_configuration/lookup/BeamLine_Undulator_toGap.txt", poles=80, length=2.0, - id_gap_lookup_table_path="/dls_sw/i22/software/daq_configuration/lookup/BeamLine_Undulator_toGap.txt", ) -def slits_1( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> Slits: - return numbered_slits( - 1, - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory() +def slits_1() -> Slits: + return Slits(prefix=f"{PREFIX.beamline_prefix}-AL-SLITS-01:") -def slits_2( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> Slits: - return numbered_slits( - 2, - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory() +def slits_2() -> Slits: + return Slits(prefix=f"{PREFIX.beamline_prefix}-AL-SLITS-02:") -def slits_3( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> Slits: - return numbered_slits( - 3, - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory() +def slits_3() -> Slits: + return Slits(prefix=f"{PREFIX.beamline_prefix}-AL-SLITS-03:") -def slits_4( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> Slits: - return numbered_slits( - 4, - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory() +def slits_4() -> Slits: + return Slits(prefix=f"{PREFIX.beamline_prefix}-AL-SLITS-04:") -def slits_5( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> Slits: - return numbered_slits( - 5, - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory() +def slits_5() -> Slits: + return Slits(prefix=f"{PREFIX.beamline_prefix}-AL-SLITS-05:") -def slits_6( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> Slits: - return numbered_slits( - 6, - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory() +def slits_6() -> Slits: + return Slits(prefix=f"{PREFIX.beamline_prefix}-AL-SLITS-06:") -def fswitch( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> FSwitch: - return device_instantiation( - FSwitch, - "fswitch", - "-MO-FSWT-01:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory() +def fswitch() -> FSwitch: + return FSwitch( + prefix=f"{PREFIX.beamline_prefix}-MO-FSWT-01:", lens_geometry="paraboloid", cylindrical=True, lens_material="Beryllium", @@ -286,107 +191,61 @@ def fswitch( # Must document what PandAs are physically connected to # See: https://github.com/bluesky/ophyd-async/issues/284 -def panda1( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> HDFPanda: - return device_instantiation( - HDFPanda, - "panda1", - "-EA-PANDA-01:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory() +def panda1() -> HDFPanda: + return HDFPanda( + prefix=f"{PREFIX.beamline_prefix}-EA-PANDA-01:", path_provider=get_path_provider(), ) -@skip_device() -def panda2( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> HDFPanda: - return device_instantiation( - HDFPanda, - "panda2", - "-EA-PANDA-02:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory(skip=True) +def panda2() -> HDFPanda: + return HDFPanda( + prefix=f"{PREFIX.beamline_prefix}-EA-PANDA-02:", path_provider=get_path_provider(), ) -@skip_device() -def panda3( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> HDFPanda: - return device_instantiation( - HDFPanda, - "panda3", - "-EA-PANDA-03:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory(skip=True) +def panda3() -> HDFPanda: + return HDFPanda( + prefix=f"{PREFIX.beamline_prefix}-EA-PANDA-03:", path_provider=get_path_provider(), ) -@skip_device() -def panda4( - wait_for_connection: bool = True, - fake_with_ophyd_sim: bool = False, -) -> HDFPanda: - return device_instantiation( - HDFPanda, - "panda4", - "-EA-PANDA-04:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory(skip=True) +def panda4() -> HDFPanda: + return HDFPanda( + prefix=f"{PREFIX.beamline_prefix}-EA-PANDA-04:", path_provider=get_path_provider(), ) -def oav( - wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False -) -> AravisDetector: - return device_instantiation( - NXSasOAV, - "oav", - "-DI-OAV-01:", - wait_for_connection, - fake_with_ophyd_sim, +@device_factory() +def oav() -> AravisDetector: + metadata_holder = NXSasMetadataHolder( + x_pixel_size=(3.45e-3, "mm"), # Double check this figure + y_pixel_size=(3.45e-3, "mm"), + description="AVT Mako G-507B", + distance=(-1.0, "m"), + ) + return NXSasOAV( + prefix=f"{PREFIX.beamline_prefix}-DI-OAV-01:", drv_suffix="DET:", - hdf_suffix="HDF5:", - metadata_holder=NXSasMetadataHolder( - x_pixel_size=(3.45e-3, "mm"), # Double check this figure - y_pixel_size=(3.45e-3, "mm"), - description="AVT Mako G-507B", - distance=(-1.0, "m"), - ), + hdf_suffix=HDF5_PREFIX, path_provider=get_path_provider(), + metadata_holder=metadata_holder, ) -@skip_device() -def linkam( - wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False -) -> Linkam3: - return device_instantiation( - Linkam3, - "linkam", - "-EA-TEMPC-05", - wait_for_connection, - fake_with_ophyd_sim, - ) +@device_factory(skip=True) +def linkam() -> Linkam3: + return Linkam3(prefix=f"{PREFIX.beamline_prefix}-EA-TEMPC-05:") -def ppump( - wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False -) -> WatsonMarlow323Pump: +@device_factory() +def ppump() -> WatsonMarlow323Pump: """Sample Environment Peristaltic Pump""" - return device_instantiation( - WatsonMarlow323Pump, - "ppump", - "-EA-PUMP-01:", - wait_for_connection, - fake_with_ophyd_sim, - ) + return WatsonMarlow323Pump(f"{PREFIX.beamline_prefix}-EA-PUMP-01:") diff --git a/src/dodal/common/beamlines/beamline_utils.py b/src/dodal/common/beamlines/beamline_utils.py index c9ec043192..aa0eb052f7 100644 --- a/src/dodal/common/beamlines/beamline_utils.py +++ b/src/dodal/common/beamlines/beamline_utils.py @@ -1,15 +1,23 @@ import inspect from collections.abc import Callable -from typing import Final, TypeVar, cast +from typing import Annotated, Final, TypeVar, cast from bluesky.run_engine import call_in_bluesky_event_loop from ophyd import Device as OphydV1Device from ophyd.sim import make_fake_device +from ophyd_async.core import DEFAULT_TIMEOUT from ophyd_async.core import Device as OphydV2Device from ophyd_async.core import wait_for_connection as v2_device_wait_for_connection from dodal.common.types import UpdatingPathProvider -from dodal.utils import AnyDevice, BeamlinePrefix, skip_device +from dodal.utils import ( + AnyDevice, + BeamlinePrefix, + D, + DeviceInitializationController, + SkipType, + skip_device, +) DEFAULT_CONNECTION_TIMEOUT: Final[float] = 5.0 @@ -124,6 +132,28 @@ def device_instantiation( return device_instance +def device_factory( + *, + use_factory_name: Annotated[bool, "Use factory name as name of device"] = True, + timeout: Annotated[float, "Timeout for connecting to the device"] = DEFAULT_TIMEOUT, + mock: Annotated[bool, "Use Signals with mock backends for device"] = False, + skip: Annotated[ + SkipType, + "mark the factory to be (conditionally) skipped when beamline is imported by external program", + ] = False, +) -> Callable[[Callable[[], D]], DeviceInitializationController[D]]: + def decorator(factory: Callable[[], D]) -> DeviceInitializationController[D]: + return DeviceInitializationController( + factory, + use_factory_name, + timeout, + mock, + skip, + ) + + return decorator + + def set_path_provider(provider: UpdatingPathProvider): global PATH_PROVIDER diff --git a/src/dodal/common/beamlines/device_helpers.py b/src/dodal/common/beamlines/device_helpers.py index 8e699361aa..43c3c88e4c 100644 --- a/src/dodal/common/beamlines/device_helpers.py +++ b/src/dodal/common/beamlines/device_helpers.py @@ -2,6 +2,8 @@ from dodal.devices.slits import Slits from dodal.utils import skip_device +HDF5_PREFIX = "HDF5:" + @skip_device() def numbered_slits( diff --git a/src/dodal/devices/focusing_mirror.py b/src/dodal/devices/focusing_mirror.py index b2f90cc083..a8eecddfdc 100644 --- a/src/dodal/devices/focusing_mirror.py +++ b/src/dodal/devices/focusing_mirror.py @@ -130,7 +130,12 @@ class FocusingMirror(StandardReadable): """Focusing Mirror""" def __init__( - self, name, prefix, bragg_to_lat_lut_path=None, x_suffix="X", y_suffix="Y" + self, + prefix: str, + name: str = "", + bragg_to_lat_lut_path: str | None = None, + x_suffix: str = "X", + y_suffix: str = "Y", ): self.bragg_to_lat_lookup_table_path = bragg_to_lat_lut_path self.yaw_mrad = Motor(prefix + "YAW") @@ -161,12 +166,12 @@ class FocusingMirrorWithStripes(FocusingMirror): """A focusing mirror where the stripe material can be changed. This is usually done based on the energy of the beamline.""" - def __init__(self, name, prefix, *args, **kwargs): + def __init__(self, prefix: str, name: str = "", *args, **kwargs): self.stripe = epics_signal_rw(MirrorStripe, prefix + "STRP:DVAL") # apply the current set stripe setting self.apply_stripe = epics_signal_x(prefix + "CHANGE.PROC") - super().__init__(name, prefix, *args, **kwargs) + super().__init__(prefix, name, *args, **kwargs) def energy_to_stripe(self, energy_kev) -> MirrorStripe: # In future, this should be configurable per-mirror diff --git a/src/dodal/devices/linkam3.py b/src/dodal/devices/linkam3.py index f43635f00a..125ce5e3e6 100644 --- a/src/dodal/devices/linkam3.py +++ b/src/dodal/devices/linkam3.py @@ -33,7 +33,7 @@ class Linkam3(StandardReadable): tolerance: float = 0.5 settle_time: int = 0 - def __init__(self, prefix: str, name: str): + def __init__(self, prefix: str, name: str = ""): self.temp = epics_signal_r(float, prefix + "TEMP:") self.dsc = epics_signal_r(float, prefix + "DSC:") self.start_heat = epics_signal_rw(bool, prefix + "STARTHEAT:") diff --git a/src/dodal/devices/tetramm.py b/src/dodal/devices/tetramm.py index c5b75b1ed3..7e39027aae 100644 --- a/src/dodal/devices/tetramm.py +++ b/src/dodal/devices/tetramm.py @@ -219,7 +219,7 @@ def __init__( self, prefix: str, path_provider: PathProvider, - name: str, + name: str = "", type: str | None = None, **scalar_sigs: str, ) -> None: diff --git a/src/dodal/utils.py b/src/dodal/utils.py index af1b2cb3bd..f2f8a29d74 100644 --- a/src/dodal/utils.py +++ b/src/dodal/utils.py @@ -1,3 +1,4 @@ +import functools import importlib import inspect import os @@ -6,13 +7,14 @@ import string from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass -from functools import wraps +from functools import update_wrapper, wraps from importlib import import_module from inspect import signature from os import environ from types import ModuleType from typing import ( Any, + Generic, Protocol, TypeGuard, TypeVar, @@ -35,6 +37,7 @@ Triggerable, WritesExternalAssets, ) +from bluesky.run_engine import call_in_bluesky_event_loop from ophyd.device import Device as OphydV1Device from ophyd_async.core import Device as OphydV2Device @@ -99,6 +102,8 @@ def __post_init__(self): T = TypeVar("T", bound=AnyDevice) +D = TypeVar("D", bound=OphydV2Device) +SkipType = bool | Callable[[], bool] def skip_device(precondition=lambda: True): @@ -114,6 +119,91 @@ def wrapper(*args, **kwds) -> T: return decorator +class DeviceInitializationController(Generic[D]): + def __init__( + self, + factory: Callable[[], D], + use_factory_name: bool, + timeout: float, + mock: bool, + skip: SkipType, + ): + self._factory: Callable[[], D] = functools.cache(factory) + self._use_factory_name = use_factory_name + self._timeout = timeout + self._mock = mock + self._skip = skip + update_wrapper(self, factory) + + @property + def skip(self) -> bool: + return self._skip() if callable(self._skip) else self._skip + + def cache_clear(self) -> None: + """Clears the controller's internal cached instance of the device, if present. + Noop if not.""" + + # Functools adds the cache_clear function via setattr so the type checker + # does not pick it up. + self._factory.cache_clear() # type: ignore + + def __call__( + self, + connect_immediately: bool = False, + name: str | None = None, + connection_timeout: float | None = None, + mock: bool | None = None, + ) -> D: + """Returns an instance of the Device the wrapped factory produces: the same + instance will be returned if this method is called multiple times, and arguments + may be passed to override this Controller's configuration. + Once the device is connected, the value of mock must be consistent, or connect + must be False. + + + Args: + connect_immediately (bool, default False): whether to call connect on the + device before returning it- connect is idempotent for ophyd-async devices. + Not connecting to the device allows for the instance to be created prior + to the RunEngine event loop being configured or for connect to be called + lazily e.g. by the `ensure_connected` stub. + name (str | None, optional): an override name to give the device, which is + also used to name its children. Defaults to None, which does not name the + device unless the device has no name and this Controller is configured to + use_factory_name, which propagates the name of the wrapped factory + function to the device instance. + connection_timeout (float | None, optional): an override timeout length in + seconds for the connect method, if it is called. Defaults to None, which + defers to the timeout configured for this Controller: the default uses + ophyd_async's DEFAULT_TIMEOUT. + mock (bool | None, optional): overrides whether to connect to Mock signal + backends, if connect is called. Defaults to None, which uses the mock + parameter of this Controller. This value must be used consistently when + connect is called on the Device. + + Returns: + D: a singleton instance of the Device class returned by the wrapped factory. + """ + device = self._factory() + + if connect_immediately: + call_in_bluesky_event_loop( + device.connect( + timeout=connection_timeout + if connection_timeout is not None + else self._timeout, + mock=mock if mock is not None else self._mock, + ) + ) + + if name: + device.set_name(name) + elif not device.name and self._use_factory_name: + device.set_name(self._factory.__name__) + + return device + + def make_device( module: str | ModuleType, device_name: str, @@ -206,7 +296,33 @@ def invoke_factories( dependent_name = leaves.pop() params = {name: devices[name] for name in dependencies[dependent_name]} try: - devices[dependent_name] = factories[dependent_name](**params, **kwargs) + factory = factories[dependent_name] + if isinstance(factory, DeviceInitializationController): + # For now we translate the old-style parameters that + # device_instantiation expects. Once device_instantiation is gone and + # replaced with DeviceInitializationController we can formalise the + # API of make_all_devices and make these parameters explicit. + # https://github.com/DiamondLightSource/dodal/issues/844 + mock = kwargs.get( + "mock", + kwargs.get( + "fake_with_ophyd_sim", + False, + ), + ) + connect_immediately = kwargs.get( + "connect_immediately", + kwargs.get( + "wait_for_connection", + False, + ), + ) + devices[dependent_name] = factory( + mock=mock, + connect_immediately=connect_immediately, + ) + else: + devices[dependent_name] = factory(**params, **kwargs) except Exception as e: exceptions[dependent_name] = e @@ -268,6 +384,8 @@ def collect_factories( def _is_device_skipped(func: AnyDeviceFactory) -> bool: + if isinstance(func, DeviceInitializationController): + return func.skip return getattr(func, "__skip__", False) diff --git a/tests/common/beamlines/test_beamline_utils.py b/tests/common/beamlines/test_beamline_utils.py index 8e33eee798..58256e59bc 100644 --- a/tests/common/beamlines/test_beamline_utils.py +++ b/tests/common/beamlines/test_beamline_utils.py @@ -1,4 +1,5 @@ import asyncio +import functools from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest @@ -13,11 +14,12 @@ from dodal.beamlines import i03 from dodal.common.beamlines import beamline_utils from dodal.devices.eiger import EigerDetector +from dodal.devices.focusing_mirror import FocusingMirror from dodal.devices.motors import XYZPositioner from dodal.devices.smargon import Smargon from dodal.devices.zebra import Zebra from dodal.log import LOGGER -from dodal.utils import make_all_devices +from dodal.utils import DeviceInitializationController, make_all_devices @pytest.fixture(autouse=True) @@ -133,3 +135,83 @@ def test_wait_for_v2_device_connection_passes_through_timeout(kwargs, expected_t mock=ANY, timeout=expected_timeout, ) + + +def dummy_mirror() -> FocusingMirror: + mirror = MagicMock(spec=FocusingMirror) + connect = AsyncMock() + mirror.connect = connect + + def set_name(name: str): + mirror.name = name # type: ignore + + mirror.set_name.side_effect = set_name + mirror.set_name("") + return mirror + + +@beamline_utils.device_factory(mock=True) +def dummy_mirror_as_device_factory() -> FocusingMirror: + return dummy_mirror() + + +@beamline_utils.device_factory(mock=True) +@functools.lru_cache +def cached_dummy_mirror_as_device_factory() -> FocusingMirror: + return dummy_mirror() + + +def test_device_controller_name_propagated(): + mirror = dummy_mirror_as_device_factory(name="foo") + assert mirror.name == "foo" + + +def test_device_controller_connection_is_lazy(): + mirror = dummy_mirror_as_device_factory(name="foo") + assert mirror.connect.call_count == 0 # type: ignore + + +def test_device_controller_eager_connect(RE): + mirror = dummy_mirror_as_device_factory(connect_immediately=True) + assert mirror.connect.call_count == 1 # type: ignore + + +@pytest.mark.parametrize( + "factory", + [ + dummy_mirror_as_device_factory, + # The second test case confirms that if, for some reason, we use a device + # factory decorated with @lru_cache, dodal is not affected and will still cache + # the same device instance internally. We actually also use lru_cache + # internally so this test case is just a sanity check to prove it is + # idempotent. + cached_dummy_mirror_as_device_factory, + ], +) +def test_device_cached(factory: DeviceInitializationController): + mirror_1 = factory() + mirror_2 = factory() + assert mirror_1 is mirror_2 + + +def test_device_cache_can_be_cleared(): + mirror_1 = dummy_mirror_as_device_factory() + dummy_mirror_as_device_factory.cache_clear() + + mirror_2 = dummy_mirror_as_device_factory() + assert mirror_1 is not mirror_2 + + +def test_skip(RE): + skip = True + + def _skip() -> bool: + return skip + + controller = beamline_utils.device_factory(skip=_skip)(dummy_mirror) + + assert isinstance(controller, DeviceInitializationController) + assert controller.skip + + skip = False + assert not controller.skip diff --git a/tests/common/beamlines/test_device_instantiation.py b/tests/common/beamlines/test_device_instantiation.py index eac84157f5..4584f98f76 100644 --- a/tests/common/beamlines/test_device_instantiation.py +++ b/tests/common/beamlines/test_device_instantiation.py @@ -1,9 +1,9 @@ from typing import Any import pytest +from ophyd_async.core import NotConnected from dodal.beamlines import all_beamline_modules -from dodal.common.beamlines import beamline_utils from dodal.utils import BLUESKY_PROTOCOLS, make_all_devices @@ -21,15 +21,17 @@ def test_device_creation(RE, module_and_devices_for_beamline): Ensures that for every beamline all device factories are using valid args and creating types that conform to Bluesky protocols. """ - module, devices, exceptions = module_and_devices_for_beamline - assert not exceptions - for device_name, device in devices.items(): - assert device_name in beamline_utils.ACTIVE_DEVICES, ( - f"No device named {device_name} was created for {module}, " - f"devices are {beamline_utils.ACTIVE_DEVICES.keys()}" - ) - assert follows_bluesky_protocols(device) - assert len(beamline_utils.ACTIVE_DEVICES) == len(devices) + _, devices, exceptions = module_and_devices_for_beamline + if len(exceptions) > 0: + raise NotConnected(exceptions) + devices_not_following_bluesky_protocols = [ + name + for name, device in devices.items() + if not follows_bluesky_protocols(device) + ] + assert ( + len(devices_not_following_bluesky_protocols) == 0 + ), f"{devices_not_following_bluesky_protocols} do not follow bluesky protocols" @pytest.mark.parametrize( @@ -47,5 +49,13 @@ def test_devices_are_identical(RE, module_and_devices_for_beamline): include_skipped=True, fake_with_ophyd_sim=True, ) - for device_name in devices_a.keys(): - assert devices_a[device_name] is devices_b[device_name] + non_identical_names = [ + device_name + for device_name, device in devices_a.items() + if device is not devices_b[device_name] + ] + total_number_of_devices = len(devices_a) + non_identical_number_of_devies = len(devices_a) + assert ( + len(non_identical_names) == 0 + ), f"{non_identical_number_of_devies}/{total_number_of_devices} devices were not identical: {non_identical_names}" diff --git a/tests/fake_device_factory_beamline.py b/tests/fake_device_factory_beamline.py new file mode 100644 index 0000000000..3ab9e7a08d --- /dev/null +++ b/tests/fake_device_factory_beamline.py @@ -0,0 +1,34 @@ +from unittest.mock import AsyncMock, MagicMock + +from bluesky.protocols import Readable, Reading, SyncOrAsync +from event_model.documents.event_descriptor import DataKey +from ophyd_async.core import Device + +from dodal.common.beamlines.beamline_utils import device_factory +from dodal.devices.cryostream import CryoStream + + +class ReadableDevice(Readable, Device): + def read(self) -> SyncOrAsync[dict[str, Reading]]: + return {} + + def describe(self) -> SyncOrAsync[dict[str, DataKey]]: + return {} + + +@device_factory(skip=True) +def device_a() -> ReadableDevice: + return ReadableDevice("readable") + + +@device_factory(skip=lambda: True) +def device_c() -> CryoStream: + return CryoStream("FOO:") + + +@device_factory(skip=True) +def mock_device() -> ReadableDevice: + device = MagicMock() + device.name = "mock_device" + device.connect = AsyncMock() + return device # type: ignore diff --git a/tests/test_utils.py b/tests/test_utils.py index 979e3e3cd4..e6e2afdae9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,11 @@ import os from collections.abc import Iterable, Mapping -from unittest.mock import MagicMock, patch +from typing import cast +from unittest.mock import ANY, MagicMock, Mock, patch import pytest from bluesky.protocols import Readable +from bluesky.run_engine import RunEngine from ophyd import EpicsMotor from dodal.beamlines import i03, i23 @@ -136,6 +138,85 @@ def test_make_device_dependency_throws(): make_device(fake_beamline, "device_z") +def test_device_factory_skips(): + import tests.fake_device_factory_beamline as fake_beamline + + devices, exceptions = make_all_devices(fake_beamline) + assert len(devices) == 0 + assert len(exceptions) == 0 + + +def test_device_factory_can_ignore_skip(): + import tests.fake_device_factory_beamline as fake_beamline + + devices, exceptions = make_all_devices(fake_beamline, include_skipped=True) + assert len(devices) == 3 + assert len(exceptions) == 0 + + +def test_fake_with_ophyd_sim_passed_to_device_factory(RE: RunEngine): + import tests.fake_device_factory_beamline as fake_beamline + + fake_beamline.mock_device.cache_clear() + + devices, exceptions = make_all_devices( + fake_beamline, + include_skipped=True, + fake_with_ophyd_sim=True, + connect_immediately=True, + ) + if "mock_device" in exceptions: + raise exceptions["mock_device"] + mock_device = cast(Mock, devices["mock_device"]) + mock_device.connect.assert_called_once_with(timeout=ANY, mock=True) + + +def test_mock_passed_to_device_factory(RE: RunEngine): + import tests.fake_device_factory_beamline as fake_beamline + + fake_beamline.mock_device.cache_clear() + + devices, exceptions = make_all_devices( + fake_beamline, + include_skipped=True, + mock=True, + connect_immediately=True, + ) + if "mock_device" in exceptions: + raise exceptions["mock_device"] + mock_device = cast(Mock, devices["mock_device"]) + mock_device.connect.assert_called_once_with(timeout=ANY, mock=True) + + +def test_connect_immediately_passed_to_device_factory(RE: RunEngine): + import tests.fake_device_factory_beamline as fake_beamline + + fake_beamline.mock_device.cache_clear() + + devices, exceptions = make_all_devices( + fake_beamline, + include_skipped=True, + connect_immediately=False, + ) + if "mock_device" in exceptions: + raise exceptions["mock_device"] + mock_device = cast(Mock, devices["mock_device"]) + mock_device.connect.assert_not_called() + + +def test_device_factory_can_rename(RE): + from tests.fake_device_factory_beamline import device_c + + cryo = device_c(mock=True, connect_immediately=True) + assert cryo.name == "device_c" + assert cryo.fine.name == "device_c-fine" + + cryo_2 = device_c(name="cryo") + assert cryo is cryo_2 + assert cryo_2.name == "cryo" + assert cryo_2.fine.name == "cryo-fine" + + def device_a() -> Readable: return MagicMock()