Skip to content

Commit 4440433

Browse files
authored
Allow context to deal with new style unions and add tests (#436)
Fixes #432
1 parent 4d52871 commit 4440433

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

src/blueapi/core/context.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,8 @@
44
from dataclasses import dataclass, field
55
from importlib import import_module
66
from inspect import Parameter, signature
7-
from types import ModuleType
8-
from typing import (
9-
Any,
10-
Generic,
11-
TypeVar,
12-
get_args,
13-
get_origin,
14-
get_type_hints,
15-
)
7+
from types import ModuleType, UnionType
8+
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints
169

1710
from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop
1811
from pydantic import create_model
@@ -264,7 +257,7 @@ def _type_spec_for_function(
264257
)
265258
return new_args
266259

267-
def _convert_type(self, typ: type) -> type:
260+
def _convert_type(self, typ: type | Any) -> type:
268261
"""
269262
Recursively convert a type to something that can be deserialised by
270263
pydantic. Bluesky protocols (and types that extend them) are replaced
@@ -288,6 +281,8 @@ def _convert_type(self, typ: type) -> type:
288281
if args:
289282
new_types = tuple(self._convert_type(i) for i in args)
290283
root = get_origin(typ)
284+
if root == UnionType:
285+
root = Union
291286
return root[new_types] if root else typ
292287
return typ
293288

tests/core/test_context.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import Union
34
from unittest.mock import patch
45

56
import pytest
@@ -274,6 +275,24 @@ def test_reference_type_conversion(empty_context: BlueskyContext) -> None:
274275
)
275276

276277

278+
def test_reference_type_conversion_union(empty_context: BlueskyContext) -> None:
279+
movable_ref: type = empty_context._reference(Movable)
280+
assert empty_context._convert_type(Movable) == movable_ref
281+
assert (
282+
empty_context._convert_type(Union[Movable, int]) == Union[movable_ref, int] # noqa # type: ignore
283+
)
284+
285+
286+
def test_reference_type_conversion_new_style_union(
287+
empty_context: BlueskyContext,
288+
) -> None:
289+
movable_ref: type = empty_context._reference(Movable)
290+
assert empty_context._convert_type(Movable) == movable_ref
291+
assert (
292+
empty_context._convert_type(Movable | int) == movable_ref | int # type: ignore
293+
)
294+
295+
277296
def test_default_device_reference(empty_context: BlueskyContext) -> None:
278297
def default_movable(mov: Movable = "demo") -> MsgGenerator: # type: ignore
279298
...

0 commit comments

Comments
 (0)