update to bps.collect_while_completing #536

Closed
2 changes: 2 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE/pull_request_template.md
@@ -6,3 +6,5 @@ Fixes #ISSUE

### Checks for reviewer
- [ ] Would the PR title make sense to a user on a set of release notes
- [ ] If the change requires a bump in an IOC version, is that specified in a `##Changes` section in the body of the PR
- [ ] If the change requires a bump in the PandABlocks-ioc version, is the `ophyd_async.fastcs.panda._hdf_panda.MINIMUM_PANDA_IOC` variable updated to match
1 change: 1 addition & 0 deletions .github/workflows/_test.yml
@@ -49,6 +49,7 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
pip-install: ".[dev]"

- name: Run tests
run: tox -e tests

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -21,7 +21,7 @@ jobs:
strategy:
matrix:
runs-on: ["ubuntu-latest", "windows-latest"] # can add macos-latest
python-version: ["3.10","3.11"] # 3.12 should be added when p4p is updated
python-version: ["3.10", "3.11"] # 3.12 should be added when p4p is updated
include:
# Include one that runs in the dev environment
- runs-on: "ubuntu-latest"
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -22,3 +22,11 @@ repos:
entry: ruff format --force-exclude
types: [python]
require_serial: true

- id: import-contracts
name: Ensure import directionality
pass_filenames: false
language: system
entry: lint-imports
types: [python]
require_serial: false
51 changes: 51 additions & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dev = [
"inflection",
"ipython",
"ipywidgets",
"import-linter",
"matplotlib",
"myst-parser",
"numpydoc",
@@ -164,3 +165,53 @@ lint.preview = true # so that preview mode PLC2701 is enabled
# See https://github.com/DiamondLightSource/python-copier-template/issues/154
# Remove this line to forbid private member access in tests
"tests/**/*" = ["SLF001"]


[tool.importlinter]
root_package = "ophyd_async"

[[tool.importlinter.contracts]]
name = "Core is independent"
type = "independence"
modules = "ophyd_async.core"

[[tool.importlinter.contracts]]
name = "Epics depends only on core"
type = "forbidden"
source_modules = "ophyd_async.epics"
forbidden_modules = [
"ophyd_async.fastcs",
"ophyd_async.plan_stubs",
"ophyd_async.sim",
"ophyd_async.tango",
]

[[tool.importlinter.contracts]]
name = "tango depends only on core"
type = "forbidden"
source_modules = "ophyd_async.tango"
forbidden_modules = [
"ophyd_async.epics",
"ophyd_async.fastcs",
"ophyd_async.plan_stubs",
"ophyd_async.sim",
]


[[tool.importlinter.contracts]]
name = "sim depends only on core"
type = "forbidden"
source_modules = "ophyd_async.sim"
forbidden_modules = [
"ophyd_async.epics",
"ophyd_async.fastcs",
"ophyd_async.plan_stubs",
"ophyd_async.tango",
]


[[tool.importlinter.contracts]]
name = "Fastcs depends only on core, epics, tango"
type = "forbidden"
source_modules = "ophyd_async.fastcs"
forbidden_modules = ["ophyd_async.plan_stubs", "ophyd_async.sim"]
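These contracts are enforced by the new `import-contracts` pre-commit hook above, which runs `lint-imports` against this configuration. As a rough illustration (the module below is hypothetical and not part of the PR), the "Epics depends only on core" contract allows the first import and would reject the second:

```python
# Hypothetical module, e.g. src/ophyd_async/epics/_demo.py, used only to
# illustrate the import directions the contracts allow.
from ophyd_async.core import Device  # allowed: epics may depend on core

# from ophyd_async.fastcs.panda import HDFPanda  # forbidden: epics must not import
#                                                # fastcs, so lint-imports would
#                                                # report a broken contract here


class DemoEpicsDevice(Device):
    """Placeholder device; only the imports above matter for the contract."""
```

Running `lint-imports` from the repo root (or `pre-commit run import-contracts --all-files`, assuming import-linter is installed via the `[dev]` extra) would flag the commented-out import.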
18 changes: 9 additions & 9 deletions src/ophyd_async/core/_detector.py
@@ -30,13 +30,13 @@ class DetectorTrigger(StrictEnum):
"""Type of mechanism for triggering a detector to take frames"""

#: Detector generates internal trigger for given rate
internal = "internal"
INTERNAL = "internal"
#: Expect a series of arbitrary length trigger signals
edge_trigger = "edge_trigger"
EDGE_TRIGGER = "edge_trigger"
#: Expect a series of constant width external gate signals
constant_gate = "constant_gate"
CONSTANT_GATE = "constant_gate"
#: Expect a series of variable width external gate signals
variable_gate = "variable_gate"
VARIABLE_GATE = "variable_gate"


class TriggerInfo(BaseModel):
@@ -53,7 +53,7 @@ class TriggerInfo(BaseModel):
#: - 3 times for final flat field images
number_of_triggers: NonNegativeInt | list[NonNegativeInt]
#: Sort of triggers that will be sent
trigger: DetectorTrigger = Field(default=DetectorTrigger.internal)
trigger: DetectorTrigger = Field(default=DetectorTrigger.INTERNAL)
#: What is the minimum deadtime between triggers
deadtime: float | None = Field(default=None, ge=0)
#: What is the maximum high time of the triggers
@@ -265,14 +265,14 @@ async def trigger(self) -> None:
await self.prepare(
TriggerInfo(
number_of_triggers=1,
trigger=DetectorTrigger.internal,
trigger=DetectorTrigger.INTERNAL,
deadtime=None,
livetime=None,
frame_timeout=None,
)
)
assert self._trigger_info
assert self._trigger_info.trigger is DetectorTrigger.internal
assert self._trigger_info.trigger is DetectorTrigger.INTERNAL
# Arm the detector and wait for it to finish.
indices_written = await self.writer.get_indices_written()
await self.controller.arm()
@@ -303,7 +303,7 @@ async def prepare(self, value: TriggerInfo) -> None:
Args:
value: TriggerInfo describing how to trigger the detector
"""
if value.trigger != DetectorTrigger.internal:
if value.trigger != DetectorTrigger.INTERNAL:
assert (
value.deadtime
), "Deadtime must be supplied when in externally triggered mode"
@@ -323,7 +323,7 @@ async def prepare(self, value: TriggerInfo) -> None:
self._describe, _ = await asyncio.gather(
self.writer.open(value.multiplier), self.controller.prepare(value)
)
if value.trigger != DetectorTrigger.internal:
if value.trigger != DetectorTrigger.INTERNAL:
await self.controller.arm()
self._fly_start = time.monotonic()

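The `DetectorTrigger` members are now upper case, so code that builds a `TriggerInfo` for external triggering changes spelling accordingly. A minimal sketch of the new form (the numbers are invented, not from this PR):

```python
from ophyd_async.core import DetectorTrigger, TriggerInfo

# Before this PR: trigger=DetectorTrigger.edge_trigger
trigger_info = TriggerInfo(
    number_of_triggers=10,
    trigger=DetectorTrigger.EDGE_TRIGGER,
    deadtime=1e-3,  # required whenever the trigger is not INTERNAL
    livetime=None,
    frame_timeout=None,
)
# A StandardDetector would then be primed with `await detector.prepare(trigger_info)`.
```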
32 changes: 26 additions & 6 deletions src/ophyd_async/core/_device.py
@@ -71,13 +71,16 @@ class Device(HasName, Connectable):
_connect_task: asyncio.Task | None = None
# The mock if we have connected in mock mode
_mock: LazyMock | None = None
# The separator to use when making child names
_child_name_separator: str = "-"

def __init__(
self, name: str = "", connector: DeviceConnector | None = None
) -> None:
self._connector = connector or DeviceConnector()
self._connector.create_children_from_annotations(self)
self.set_name(name)
if name:
self.set_name(name)

@property
def name(self) -> str:
@@ -97,21 +100,30 @@ def log(self) -> LoggerAdapter:
getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name}
)

def set_name(self, name: str):
def set_name(self, name: str, *, child_name_separator: str | None = None) -> None:
"""Set ``self.name=name`` and each ``self.child.name=name+"-child"``.

Parameters
----------
name:
New name to set
child_name_separator:
Use this as a separator instead of the default "-". Use "_" to make the same
names as the equivalent ophyd sync device.
"""
self._name = name
if child_name_separator:
self._child_name_separator = child_name_separator
# Ensure logger is recreated after a name change
if "log" in self.__dict__:
del self.log
for child_name, child in self.children():
child_name = f"{self.name}-{child_name.strip('_')}" if self.name else ""
child.set_name(child_name)
for attr_name, child in self.children():
child_name = (
f"{self.name}{self._child_name_separator}{attr_name}"
if self.name
else ""
)
child.set_name(child_name, child_name_separator=self._child_name_separator)

def __setattr__(self, name: str, value: Any) -> None:
# Bear in mind that this function is called *a lot*, so
@@ -147,6 +159,10 @@ async def connect(
timeout:
Time to wait before failing with a TimeoutError.
"""
assert hasattr(self, "_connector"), (
f"{self}: doesn't have attribute `_connector`,"
" did you call `super().__init__` in your `__init__` method?"
)
if mock:
# Always connect in mock mode serially
if isinstance(mock, LazyMock):
@@ -247,6 +263,8 @@ class DeviceCollector:
set_name:
If True, call ``device.set_name(variable_name)`` on all collected
Devices
child_name_separator:
Use this as a separator if we call ``set_name``.
connect:
If True, call ``device.connect(mock)`` in parallel on all
collected Devices
@@ -271,11 +289,13 @@ def __init__(
def __init__(
self,
set_name=True,
child_name_separator: str = "-",
connect=True,
mock=False,
timeout: float = 10.0,
):
self._set_name = set_name
self._child_name_separator = child_name_separator
self._connect = connect
self._mock = mock
self._timeout = timeout
@@ -311,7 +331,7 @@ async def _on_exit(self) -> None:
for name, obj in self._objects_on_exit.items():
if name not in self._names_on_enter and isinstance(obj, Device):
if self._set_name and not obj.name:
obj.set_name(name)
obj.set_name(name, child_name_separator=self._child_name_separator)
if self._connect:
connect_coroutines[name] = obj.connect(
self._mock, timeout=self._timeout
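`set_name` (and `DeviceCollector`) now accept a `child_name_separator`, defaulting to the existing "-". A minimal sketch of the effect, using a hypothetical device built from soft signals:

```python
from ophyd_async.core import Device, soft_signal_rw


class Sensor(Device):
    """Hypothetical device used only to show how child names are generated."""

    def __init__(self, name: str = "") -> None:
        self.value = soft_signal_rw(float, 0.0)
        super().__init__(name=name)


sensor = Sensor()
sensor.set_name("sensor")
assert sensor.value.name == "sensor-value"   # default separator

sensor.set_name("sensor", child_name_separator="_")
assert sensor.value.name == "sensor_value"   # matches ophyd sync naming
```

`DeviceCollector(child_name_separator="_")` forwards the same keyword when it names collected devices, and `Device.__init__` now skips `set_name` entirely when the name is left empty.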
6 changes: 3 additions & 3 deletions src/ophyd_async/core/_mock_signal_utils.py
@@ -1,5 +1,5 @@
from collections.abc import Awaitable, Callable, Iterable
from contextlib import asynccontextmanager, contextmanager
from contextlib import contextmanager
from unittest.mock import AsyncMock, Mock

from ._device import Device
@@ -40,8 +40,8 @@ def set_mock_put_proceeds(signal: Signal, proceeds: bool):
backend.put_proceeds.clear()


@asynccontextmanager
async def mock_puts_blocked(*signals: Signal):
@contextmanager
def mock_puts_blocked(*signals: Signal):
for signal in signals:
set_mock_put_proceeds(signal, False)
yield
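`mock_puts_blocked` is now a plain (synchronous) context manager, so callers swap `async with` for `with`. A minimal sketch, assuming the helper is still exported from `ophyd_async.core`:

```python
import asyncio

from ophyd_async.core import mock_puts_blocked, soft_signal_rw


async def main() -> None:
    sig = soft_signal_rw(int, 0, name="sig")
    await sig.connect(mock=True)

    # Previously: `async with mock_puts_blocked(sig): ...`
    with mock_puts_blocked(sig):
        status = sig.set(1)     # the put stalls while the block is active
        assert not status.done
    await status                # proceeds once the context manager exits


asyncio.run(main())
```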
49 changes: 36 additions & 13 deletions src/ophyd_async/core/_signal.py
@@ -2,6 +2,7 @@

import asyncio
import functools
import time
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from typing import Any, Generic, cast

@@ -122,7 +123,7 @@ async def get_value(self) -> SignalDatatypeT:

def _callback(self, reading: Reading[SignalDatatypeT]):
self._signal.log.debug(
f"Updated subscription: reading of source {self._signal.source} changed"
f"Updated subscription: reading of source {self._signal.source} changed "
f"from {self._reading} to {reading}"
)
self._reading = reading
@@ -425,6 +426,7 @@ async def observe_value(
signal: SignalR[SignalDatatypeT],
timeout: float | None = None,
done_status: Status | None = None,
done_timeout: float | None = None,
) -> AsyncGenerator[SignalDatatypeT, None]:
"""Subscribe to the value of a signal so it can be iterated from.

@@ -439,25 +441,44 @@
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.
done_timeout:
If given, the maximum time to watch a signal, in seconds. If the loop is still
being watched after this length, raise asyncio.TimeoutError. This should be used
instead of an 'asyncio.wait_for' timeout.

Notes
-----
Due to a rare condition with busy signals, it is not recommended to use this
function with asyncio.timeout, including in an 'asyncio.wait_for' loop. Instead,
this timeout should be given to the done_timeout parameter.

Example usage::

async for value in observe_value(sig):
do_something_with(value)
"""

async for _, value in observe_signals_value(
signal, timeout=timeout, done_status=done_status
signal,
timeout=timeout,
done_status=done_status,
done_timeout=done_timeout,
):
yield value


def _get_iteration_timeout(
timeout: float | None, overall_deadline: float | None
) -> float | None:
overall_deadline = overall_deadline - time.monotonic() if overall_deadline else None
return min([x for x in [overall_deadline, timeout] if x is not None], default=None)


async def observe_signals_value(
*signals: SignalR[SignalDatatypeT],
timeout: float | None = None,
done_status: Status | None = None,
done_timeout: float | None = None,
) -> AsyncGenerator[tuple[SignalR[SignalDatatypeT], SignalDatatypeT], None]:
"""Subscribe to the value of a signal so it can be iterated from.

@@ -472,6 +493,10 @@ async def observe_signals_value(
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.
done_timeout:
If given, the maximum time to watch a signal, in seconds. If the loop is still
being watched after this length, raise asyncio.TimeoutError. This should be used
instead of an 'asyncio.wait_for' timeout.

Notes
-----
@@ -486,12 +511,6 @@ async def observe_signals_value(
q: asyncio.Queue[tuple[SignalR[SignalDatatypeT], SignalDatatypeT] | Status] = (
asyncio.Queue()
)
if timeout is None:
get_value = q.get
else:

async def get_value():
return await asyncio.wait_for(q.get(), timeout)

cbs: dict[SignalR, Callback] = {}
for signal in signals:
@@ -504,13 +523,17 @@ def queue_value(value: SignalDatatypeT, signal=signal):

if done_status is not None:
done_status.add_callback(q.put_nowait)

overall_deadline = time.monotonic() + done_timeout if done_timeout else None
try:
while True:
# yield here in case something else is filling the queue
# like in test_observe_value_times_out_with_no_external_task()
await asyncio.sleep(0)
item = await get_value()
if overall_deadline and time.monotonic() >= overall_deadline:
raise asyncio.TimeoutError(
f"observe_value was still observing signals "
f"{[signal.source for signal in signals]} after "
f"timeout {done_timeout}s"
)
iteration_timeout = _get_iteration_timeout(timeout, overall_deadline)
item = await asyncio.wait_for(q.get(), iteration_timeout)
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
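`observe_value` and `observe_signals_value` gain a `done_timeout` that bounds the whole watch, while the existing `timeout` still bounds the wait for each individual update. A rough usage sketch (the signal and threshold below are invented for illustration):

```python
import asyncio

from ophyd_async.core import observe_value, soft_signal_rw


async def watch_position() -> None:
    position = soft_signal_rw(float, 0.0, name="position")
    await position.connect(mock=True)

    try:
        # timeout=1.0 -> each update must arrive within 1 s;
        # done_timeout=10.0 -> give up after 10 s overall, instead of wrapping
        # the loop in asyncio.wait_for as the new docstring discourages.
        async for value in observe_value(position, timeout=1.0, done_timeout=10.0):
            if value >= 42:
                break
    except asyncio.TimeoutError:
        pass  # either no update within 1 s or still watching after 10 s


asyncio.run(watch_position())
```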