Skip to content

Commit 26856c4

Browse files
committed
fix test failure
1 parent c762109 commit 26856c4

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

array_api_compat/common/_helpers.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import inspect
1414
import warnings
1515
from collections.abc import Generator
16-
from types import ModuleType
1716
from typing import Optional, Union, Any
1817

1918
from ._typing import Array, Device, Namespace
@@ -801,10 +800,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
801800

802801

803802
def _device_ctx(
804-
bare_xp: ModuleType, device: Device, like: Array | None = None
803+
bare_xp: Namespace, device: Device, like: Array | None = None
805804
) -> Generator[None]:
806805
"""Context manager which changes the current device in CuPy.
807-
806+
808807
Used internally by array creation functions in common._aliases.
809808
"""
810809
if device is None:
@@ -832,7 +831,7 @@ def _device_ctx(
832831
raise AssertionError("unreachable") # pragma: nocover
833832

834833

835-
def _validate_device(bare_xp: ModuleType, device: Device) -> None:
834+
def _validate_device(bare_xp: Namespace, device: Device) -> None:
836835
with _device_ctx(bare_xp, device):
837836
pass
838837

0 commit comments

Comments
 (0)