From 010bb45126bc37ff84348e65fb071928c0951c40 Mon Sep 17 00:00:00 2001 From: Mark Wolfman Date: Mon, 21 Oct 2024 12:31:10 -0500 Subject: [PATCH 1/2] Copied and modified original implementation from Haven. --- README.rst | 9 +- pyproject.toml | 6 +- pytest.ini | 4 + src/guarneri/__init__.py | 8 +- src/guarneri/_version.py | 154 +++++++++----- src/guarneri/exceptions.py | 4 + src/guarneri/iconfig_example.toml | 23 ++ src/guarneri/instrument.py | 296 ++++++++++++++++++++++++++ src/guarneri/tests/test_instrument.py | 128 +++++++++++ 9 files changed, 567 insertions(+), 65 deletions(-) create mode 100644 pytest.ini create mode 100644 src/guarneri/exceptions.py create mode 100644 src/guarneri/iconfig_example.toml create mode 100644 src/guarneri/instrument.py create mode 100644 src/guarneri/tests/test_instrument.py diff --git a/README.rst b/README.rst index d03ef45..f8b7d7a 100644 --- a/README.rst +++ b/README.rst @@ -145,7 +145,14 @@ The following will download the package and load it into the python environment. .. code-block:: bash git clone https://github.com/spc-group/guarneri - pip install -e guarneri + pip install guarneri + +For development of guarneri, install as an editable project with all +development dependencies using: + +.. code-block:: bash + + pip install -e ".[dev]" Running the Tests diff --git a/pyproject.toml b/pyproject.toml index c4ead5c..0c05353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] description = "guarneri" readme = "README.rst" -requires-python = ">=3.7" +requires-python = ">=3.10" classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", @@ -19,10 +19,10 @@ classifiers = [ "Topic :: System :: Hardware", ] keywords = [] -dependencies = [] +dependencies = ["ophyd", "ophyd-async", "ophyd-registry", "tomlkit"] [project.optional-dependencies] -dev = ["black", "pytest", "flake8", "isort"] +dev = ["black", "pytest", "pytest-asyncio", "flake8", "isort"] [project.urls] Homepage = "https://github.com/spc-group/guarneri" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..afedc2c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +asyncio_mode = auto +testpaths = + src/guarneri/tests diff --git a/src/guarneri/__init__.py b/src/guarneri/__init__.py index 21fb152..eebdc65 100644 --- a/src/guarneri/__init__.py +++ b/src/guarneri/__init__.py @@ -1,7 +1,11 @@ from ._version import get_versions -__version__ = get_versions()['version'] +__version__ = get_versions()["version"] del get_versions # TODO: fill this in with appropriate star imports: -__all__ = [] +__all__ = ["Instrument", "exceptions"] + + +from . import exceptions +from .instrument import Instrument diff --git a/src/guarneri/_version.py b/src/guarneri/_version.py index b2fd1e1..478275c 100644 --- a/src/guarneri/_version.py +++ b/src/guarneri/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -58,17 +57,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -76,10 +76,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -116,16 +119,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -181,7 +190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -190,7 +199,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -198,19 +207,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -225,8 +241,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -234,10 +249,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -260,17 +284,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -279,10 +302,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -293,13 +318,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -330,8 +355,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -445,11 +469,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -469,9 +495,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -485,8 +515,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -495,13 +524,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for i in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -515,6 +547,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/src/guarneri/exceptions.py b/src/guarneri/exceptions.py new file mode 100644 index 0000000..6029e00 --- /dev/null +++ b/src/guarneri/exceptions.py @@ -0,0 +1,4 @@ +class InvalidConfiguration(TypeError): + """The configuration files for Haven are missing keys.""" + + ... diff --git a/src/guarneri/iconfig_example.toml b/src/guarneri/iconfig_example.toml new file mode 100644 index 0000000..ce866a2 --- /dev/null +++ b/src/guarneri/iconfig_example.toml @@ -0,0 +1,23 @@ +[beamline] +name = "APS Beamline (sector unknown)" +# Whether to connect to hardware (true) or use mocked signals (false) +hardware_is_present = false + + +[[ async_device ]] +scaler_prefix = "255idcVME:3820:" +scaler_channel = 2 +preamp_prefix = "255idc:SR03:" +voltmeter_prefix = "255idc:LabJackT7_1:" +voltmeter_channel = 1 +# From V2F100: Fmax / Vmax +counts_per_volt_second = 10e6 +name = "I0" + + +[[ threaded_device ]] +prefix = "" +name = "" + +[[ factory_device ]] +num_devices = 5 \ No newline at end of file diff --git a/src/guarneri/instrument.py b/src/guarneri/instrument.py new file mode 100644 index 0000000..008b3eb --- /dev/null +++ b/src/guarneri/instrument.py @@ -0,0 +1,296 @@ +"""Loader for creating instances of the devices from a config file.""" + +import asyncio +import inspect +import logging +import os +import time +from pathlib import Path +from typing import Mapping, Sequence + +import tomlkit +from ophyd import Device as ThreadedDevice +from ophyd.sim import make_fake_device +from ophyd_async.core import DEFAULT_TIMEOUT, NotConnected +from ophydregistry import Registry + +from .exceptions import InvalidConfiguration + +log = logging.getLogger(__name__) + + +instrument = None + + +class Instrument: + """A beamline instrument built from config files of Ophyd devices. + + *device_classes* should be dictionary that maps configuration + section names to device classes (or similar items). + + Example: + + ```python + instrument = Instrument({ + "ion_chamber": IonChamber + "motors": load_motors, + }) + ``` + + The values in *device_classes* should be one of the following: + + 1. A device class + 2. A callable that returns an instantiated device object + 3. A callable that returns a sequence of device objects + + Parameters + ========== + device_classes + Maps config section names to device classes. + + """ + + devices: list + registry: Registry + beamline_name: str = "" + hardware_is_present: bool | None = None + + def __init__(self, device_classes: Mapping, registry: Registry | None = None): + self.devices = [] + if registry is None: + registry = Registry(auto_register=False, use_typhos=False) + self.registry = registry + self.device_classes = device_classes + + def parse_toml_file(self, fd): + """Parse TOML instrument configuration and create devices. + + An open file descriptor + """ + config = tomlkit.load(fd) + # Set global parameters + beamline = config.get("beamline", {}) + self.beamline_name = beamline.get("name", self.beamline_name) + self.hardware_is_present = beamline.get( + "hardware_is_present", self.hardware_is_present + ) + # Make devices from config file + return self.parse_config(config) + + def parse_config(self, cfg): + devices = [] + for key, Klass in self.device_classes.items(): + # Create the devices + for params in cfg.get(key, []): + print(Klass, params) + self.validate_params(params, Klass) + device = self.make_device(params, Klass) + try: + # Maybe its a list of devices? + devices.extend(device) + except TypeError: + # No, assume it's just a single device then + devices.append(device) + # Save devices for connecting to later + self.devices.extend(devices) + return devices + + def validate_params(self, params, Klass): + """Check that parameters match a Device class's initializer.""" + sig = inspect.signature(Klass) + has_kwargs = any( + [param.kind == param.VAR_KEYWORD for param in sig.parameters.values()] + ) + # Make sure we're not missing any required parameters + for key, sig_param in sig.parameters.items(): + # Check for missing parameters + param_missing = key not in params + param_required = ( + sig_param.default is sig_param.empty + and sig_param.kind != sig_param.VAR_KEYWORD + ) + if param_missing and param_required: + raise InvalidConfiguration( + f"Missing required key '{key}' for {Klass}: {params}" + ) + # Check types + if not param_missing: + try: + correct_type = isinstance(params[key], sig_param.annotation) + has_type = not issubclass(sig_param.annotation, inspect._empty) + except TypeError: + correct_type = False + has_type = False + if has_type and not correct_type: + raise InvalidConfiguration( + f"Incorrect type for {Klass} key '{key}': " + f"expected `{sig_param.annotation}` but got " + f"`{type(params[key])}`." + ) + + def make_device(self, params, Klass): + """Create the devices from their parameters.""" + # Mock threaded ophyd devices if necessary + try: + is_threaded_device = issubclass(Klass, ThreadedDevice) + except TypeError: + is_threaded_device = False + if is_threaded_device and not self.hardware_is_present: + Klass = make_fake_device(Klass) + # Turn the parameters into pure python objects + kwargs = {} + for key, param in params.items(): + if isinstance(param, tomlkit.items.Item): + kwargs[key] = param.unwrap() + else: + kwargs[key] = param + # Check if we need to injec the registry + extra_params = {} + sig = inspect.signature(Klass) + if "registry" in sig.parameters.keys(): + kwargs["registry"] = self.registry + # Create the device + result = Klass(**kwargs) + return result + + async def connect( + self, + mock: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect: bool = False, + return_exceptions: bool = False, + ): + """Connect all Devices. + + Contains a timeout that gets propagated to device.connect methods. + + Parameters + ---------- + mock: + If True then use ``MockSignalBackend`` for all Signals + timeout: + Time to wait before failing with a TimeoutError. + force_reconnect + Force the signals to establish a new connection. + """ + t0 = time.monotonic() + # Sort out which devices are which + threaded_devices = [] + async_devices = [] + for device in self.devices: + if hasattr(device, "connect"): + async_devices.append(device) + else: + threaded_devices.append(device) + # Connect to async devices + aws = ( + dev.connect(mock=mock, timeout=timeout, force_reconnect=force_reconnect) + for dev in async_devices + ) + results = await asyncio.gather(*aws, return_exceptions=True) + # Filter out the disconnected devices + new_devices = [] + exceptions = {} + for device, result in zip(async_devices, results): + if result is None: + log.debug(f"Successfully connected device {device.name}") + new_devices.append(device) + else: + # Unexpected exception, raise it so it can be handled + log.debug(f"Failed connection for device {device.name}") + exceptions[device.name] = result + # Connect to threaded devices + timeout_reached = False + while not timeout_reached and len(threaded_devices) > 0: + # Remove any connected devices for the running list + connected_devices = [ + dev for dev in threaded_devices if getattr(dev, "connected", True) + ] + new_devices.extend(connected_devices) + threaded_devices = [ + dev for dev in threaded_devices if dev not in connected_devices + ] + # Tick the clock for the next round through the while loop + await asyncio.sleep(min((0.05, timeout / 10.0))) + timeout_reached = (time.monotonic() - t0) > timeout + # Add disconnected devices to the exception list + for device in threaded_devices: + try: + device.wait_for_connection(timeout=0) + except TimeoutError as exc: + exceptions[device.name] = NotConnected(str(exc)) + # Raise exceptions if any were present + if return_exceptions: + return new_devices, exceptions + if len(exceptions) > 0: + raise NotConnected(exceptions) + return new_devices + + async def load( + self, + connect: bool = True, + device_classes: Mapping | None = None, + config_files: Sequence[Path] | None = None, + return_exceptions: bool = False, + ): + """Load instrument specified in config files. + + Unless, explicitly overridden by the *config_files* argument, + configuration files are read from the environmental variable + HAVEN_CONFIG_FILES (separated by ':'). + + Parameters + ========== + connect + If true, establish connections for the devices now. + device_classes + A temporary set of device classes to use for this call + only. Overrides any device classes given during + initalization. + config_files + I list of file paths that will be loaded. If omitted, those + files listed in HAVEN_CONFIG_FILES will be used. + return_exceptions + If true, exceptions will be returned for further processing, + otherwise, exceptions will be raised (default). + + """ + self.devices = [] + # Decide which config files to use + if config_files is None: + env_key = "HAVEN_CONFIG_FILES" + if env_key in os.environ.keys(): + config_files = os.environ.get("HAVEN_CONFIG_FILES", "") + config_files = [Path(fp) for fp in config_files.split(":")] + else: + config_files = [ + Path(__file__).parent.resolve() / "iconfig_testing.toml" + ] + # Load the instrument from config files + old_classes = self.device_classes + try: + # Temprary override of device classes + if device_classes is not None: + self.device_classes = device_classes + # Parse TOML files + for fp in config_files: + with open(fp, mode="tr", encoding="utf-8") as fd: + self.parse_toml_file(fd) + finally: + self.device_classes = old_classes + # Connect the devices + if connect: + new_devices, exceptions = await self.connect( + mock=not self.hardware_is_present, return_exceptions=True + ) + else: + new_devices = self.devices + exceptions = [] + # Registry devices + for device in new_devices: + self.registry.register(device) + # Raise exceptions + if return_exceptions: + return exceptions + elif len(exceptions) > 0: + raise NotConnected(exceptions) diff --git a/src/guarneri/tests/test_instrument.py b/src/guarneri/tests/test_instrument.py new file mode 100644 index 0000000..22d447e --- /dev/null +++ b/src/guarneri/tests/test_instrument.py @@ -0,0 +1,128 @@ +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest +from ophyd import Component +from ophyd import Device as DeviceV1 +from ophyd import EpicsSignal +from ophyd_async.core import Device + +from guarneri import Instrument, exceptions + +toml_file = Path(__file__).parent.parent.resolve() / "iconfig_example.toml" + + +class ThreadedDevice(DeviceV1): + description = Component(EpicsSignal, ".DESC") + + +class AsyncDevice(Device): + def __init__( + self, + scaler_prefix: str, + scaler_channel: int, + preamp_prefix: str, + voltmeter_prefix: str, + voltmeter_channel: int, + counts_per_volt_second: float, + name="", + auto_name: bool = None, + ): + super().__init__(name=name) + + +def load_devices(num_devices=0): + for i in num_devices: + yield AsyncDevice() + + +@pytest.fixture() +def instrument(): + inst = Instrument( + { + "async_device": AsyncDevice, + "factory_device": load_devices, + "threaded_device": ThreadedDevice, + } + ) + with open(toml_file, mode="tr", encoding="utf-8") as fd: + inst.parse_toml_file(fd) + return inst + + +def test_global_parameters(instrument): + """Check that we loaded keys that apply to the whole beamline.""" + assert instrument.beamline_name == "APS Beamline (sector unknown)" + assert instrument.hardware_is_present == False + + +def test_validate_missing_params(instrument): + defn = { + # "scaler_prefix": "scaler_1:", + # "scaler_channel": 3, + # "preamp_prefix": "preamp_1:", + # "voltmeter_prefix": "labjack_1:", + # "voltmeter_channel": 1, + # "counts_per_volt_second": 1e-6, + # "name": "", + # "auto_name": None, + } + with pytest.raises(Exception): + instrument.validate_params(defn, AsyncDevice) + + +def test_validate_optional_params(instrument): + defn = { + "scaler_prefix": "scaler_1:", + "scaler_channel": 3, + "preamp_prefix": "preamp_1:", + "voltmeter_prefix": "labjack_1:", + "voltmeter_channel": 1, + "counts_per_volt_second": 1e-6, + # "name": "", + # "auto_name": None, + } + instrument.validate_params(defn, AsyncDevice) + + +def test_validate_wrong_types(instrument): + defn = { + "scaler_prefix": "scaler_1:", + "scaler_channel": "3", + "preamp_prefix": "preamp_1:", + "voltmeter_prefix": "labjack_1:", + "voltmeter_channel": "1", + "counts_per_volt_second": 1e-6, + "name": "", + "auto_name": None, + } + with pytest.raises(exceptions.InvalidConfiguration): + instrument.validate_params(defn, AsyncDevice) + + +async def test_connect(instrument): + async_devices = [d for d in instrument.devices if hasattr(d, "_connect_task")] + sync_devices = [d for d in instrument.devices if hasattr(d, "connected")] + assert len(async_devices) > 0 + assert len(sync_devices) > 0 + # Are devices disconnected to start with? + assert all([d._connect_task is None for d in async_devices]) + assert all([not d.connected is None for d in sync_devices]) + # Connect the device + await instrument.connect(mock=True) + # Are devices connected afterwards? + # NB: This doesn't actually test the code for threaded devices + assert all([d._connect_task.done for d in async_devices]) + + +async def test_load(monkeypatch): + instrument = Instrument({}) + # Mock out the relevant methods to test + monkeypatch.setattr(instrument, "parse_toml_file", MagicMock()) + monkeypatch.setattr(instrument, "connect", AsyncMock(return_value=([], []))) + monkeypatch.setenv("HAVEN_CONFIG_FILES", str(toml_file), prepend=False) + # Execute the loading step + await instrument.load() + # Check that the right methods were called + instrument.parse_toml_file.assert_called_once() + instrument.connect.assert_called_once_with(mock=True, return_exceptions=True) From 4b162f55d56b0da2130af066d5e6a34771012238 Mon Sep 17 00:00:00 2001 From: Mark Wolfman Date: Mon, 21 Oct 2024 12:50:42 -0500 Subject: [PATCH 2/2] Fixed CI definition. --- .github/workflows/ci.yml | 72 +++++++++++++++++++--------------------- pyproject.toml | 2 +- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bab01b8..c08d534 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,43 +1,41 @@ -# Based on tutorial: https://autobencoder.com/2020-08-24-conda-actions/ +# https://docs.github.com/en/actions/use-cases-and-examples/building-and-testing/building-and-testing-python -name: Tests -on: - push: - branches: - - main - pull_request: - branches: - - main - -env: - DISPLAY: ":99" - PYDM_DEFAULT_PROTOCOL: ca - BLUESKY_DEBUG_CALLBACKS: 1 +name: Guarneri +on: [push] jobs: - build-linux: - defaults: - run: - shell: bash -l {0} - runs-on: ubuntu-22.04 + build: + runs-on: ubuntu-latest + timeout-minutes: 10 strategy: - max-parallel: 5 + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: - - uses: actions/checkout@v3 - - name: Install \{ - run: pip install -e ".[dev]" - - name: Environment info - run: | - env - pip freeze - - name: Lint - run: | - # Check for syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # Make sure black code formatting is applied - black --check --preview src/ - # Make sure import orders are correct - isort --check src/ - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + # You can test your matrix by printing the current Python version + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Install guarneri + run: pip install -e ".[dev]" + - name: Environment info + run: | + env + pip freeze + - name: Lint + run: | + # Check for syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # Make sure black code formatting is applied + black --check --preview src/ + # Make sure import orders are correct + isort --check src/ + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Haven tests with pytest in Xvfb + run: pytest -vv diff --git a/pyproject.toml b/pyproject.toml index 0c05353..ad71461 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0"] +requires = ["setuptools>=61.0", "setuptools-scm>=8.0"] build-backend = "setuptools.build_meta" [project]