Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(typing): Resolve mypy==1.11.0 issues in plugin_registry #3487

Merged
merged 2 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion altair/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
VegaLiteCompilerType = Callable[[dict], dict]


class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType]):
class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType, dict]):
pass
18 changes: 8 additions & 10 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
Dict,
overload,
runtime_checkable,
Callable,
)
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, ParamSpec, Concatenate
from pathlib import Path
from functools import partial
import sys
Expand Down Expand Up @@ -82,17 +83,14 @@ def is_data_type(obj: Any) -> TypeIs[DataType]:
# VegaLite spec, after the Data model has been put into a schema compliant
# form.
# ==============================================================================
class DataTransformerType(Protocol):
@overload
def __call__(self, data: None = None, **kwargs) -> DataTransformerType: ...
@overload
def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict: ...
def __call__(
self, data: DataType | None = None, **kwargs
) -> DataTransformerType | VegaLiteDataDict: ...

P = ParamSpec("P")
# NOTE: `Any` required due to the complexity of existing signatures imported in `altair.vegalite.v5.data.py`
R = TypeVar("R", VegaLiteDataDict, Any)
DataTransformerType = Callable[Concatenate[DataType, P], R]
binste marked this conversation as resolved.
Show resolved Hide resolved

class DataTransformerRegistry(PluginRegistry[DataTransformerType]):

class DataTransformerRegistry(PluginRegistry[DataTransformerType, R]):
_global_settings = {"consolidate_datasets": True}

@property
Expand Down
2 changes: 1 addition & 1 deletion altair/utils/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
]


class RendererRegistry(PluginRegistry[RendererType]):
class RendererRegistry(PluginRegistry[RendererType, MimeBundleType]):
entrypoint_err_messages = {
"notebook": textwrap.dedent(
"""
Expand Down
95 changes: 67 additions & 28 deletions altair/utils/plugin_registry.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
from __future__ import annotations

from functools import partial
from typing import Any, Generic, TypeVar, cast, Callable, TYPE_CHECKING
from typing import Any, Generic, cast, Callable, TYPE_CHECKING
from typing_extensions import TypeAliasType, TypeVar, TypeIs

from importlib.metadata import entry_points

from altair.utils.deprecation import deprecated_warn

if TYPE_CHECKING:
from types import TracebackType

T = TypeVar("T")
R = TypeVar("R")
Plugin = TypeAliasType("Plugin", Callable[..., R], type_params=(R,))
PluginT = TypeVar("PluginT", bound=Plugin[Any])
IsPlugin = Callable[[object], TypeIs[Plugin[Any]]]
Copy link
Member Author

@dangotbanned dangotbanned Jul 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an example of how this could be used.

@runtime_checkable would not be useful as it doesn't check types. Any callable would pass, regardless of the return type.

alt.utils.theme.py

from .plugin_registry import PluginRegistry
from typing import Callable
from typing_extensions import TypeAlias, TypeIs

ThemeType: TypeAlias = Callable[..., dict]

def is_theme_plugin(obj: Callable[..., Any]) -> TypeIs[ThemeType]:
    from inspect import signature
    from typing import get_origin

    sig = signature(obj)
    ret = sig.return_annotation
    return ret is dict or get_origin(ret) is dict

class ThemeRegistry(PluginRegistry[ThemeType, dict]):
    pass

alt.vegalite.v5.theme.py

from typing import Final
from ...utils.theme import ThemeRegistry, is_theme_plugin

ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.theme"
themes = ThemeRegistry(entry_point_group=ENTRY_POINT_GROUP, plugin_type=is_theme_plugin)



def _is_type(tp: type[T], /) -> Callable[[object], TypeIs[type[T]]]:
"""Converts a type to guard function.

Added for compatibility with original `PluginRegistry` default.
"""

def func(obj: object, /) -> TypeIs[type[T]]:
return isinstance(obj, tp)

PluginType = TypeVar("PluginType")
return func


class NoSuchEntryPoint(Exception):
Expand Down Expand Up @@ -49,7 +67,7 @@ def __repr__(self) -> str:
return f"{self.registry.__class__.__name__}.enable({self.name!r})"


class PluginRegistry(Generic[PluginType]):
class PluginRegistry(Generic[PluginT, R]):
"""A registry for plugins.

This is a plugin registry that allows plugins to be loaded/registered
Expand All @@ -74,26 +92,44 @@ class PluginRegistry(Generic[PluginType]):
# in the registry rather than passed to the plugins
_global_settings: dict[str, Any] = {}

def __init__(self, entry_point_group: str = "", plugin_type: type = Callable): # type: ignore[assignment]
def __init__(
self, entry_point_group: str = "", plugin_type: IsPlugin = callable
) -> None:
"""Create a PluginRegistry for a named entry point group.

Parameters
==========
entry_point_group: str
The name of the entry point group.
plugin_type: object
A type that will optionally be used for runtime type checking of
loaded plugins using isinstance.
plugin_type
A type narrowing function that will optionally be used for runtime
type checking loaded plugins.

References
==========
https://typing.readthedocs.io/en/latest/spec/narrowing.html
"""
self.entry_point_group: str = entry_point_group
self.plugin_type: type[Any] = plugin_type
self._active: PluginType | None = None
self.plugin_type: IsPlugin
if plugin_type is not callable and isinstance(plugin_type, type):
msg = (
f"Pass a callable `TypeIs` function to `plugin_type` instead.\n"
f"{type(self).__name__!r}(plugin_type)\n\n"
f"See also:\n"
f"https://typing.readthedocs.io/en/latest/spec/narrowing.html\n"
f"https://docs.astral.sh/ruff/rules/assert/"
)
deprecated_warn(msg, version="5.4.0")
self.plugin_type = cast(IsPlugin, _is_type(plugin_type))
else:
self.plugin_type = plugin_type
self._active: Plugin[R] | None = None
self._active_name: str = ""
self._plugins: dict[str, PluginType] = {}
self._plugins: dict[str, PluginT] = {}
self._options: dict[str, Any] = {}
self._global_settings: dict[str, Any] = self.__class__._global_settings.copy()

def register(self, name: str, value: PluginType | Any | None) -> PluginType | None:
def register(self, name: str, value: PluginT | None) -> PluginT | None:
"""Register a plugin by name and value.

This method is used for explicit registration of a plugin and shouldn't be
Expand All @@ -113,12 +149,12 @@ def register(self, name: str, value: PluginType | Any | None) -> PluginType | No
"""
if value is None:
return self._plugins.pop(name, None)
else:
assert isinstance(
value, self.plugin_type
) # Should ideally be fixed by better annotating plugin_type
elif self.plugin_type(value):
self._plugins[name] = value
return value
else:
msg = f"{type(value).__name__!r} is not compatible with {type(self).__name__!r}"
raise TypeError(msg)

def names(self) -> list[str]:
"""List the names of the registered and entry points plugins."""
Expand Down Expand Up @@ -163,7 +199,7 @@ def _enable(self, name: str, **options) -> None:
raise ValueError(self.entrypoint_err_messages[name]) from err
else:
raise NoSuchEntryPoint(self.entry_point_group, name) from err
value = cast(PluginType, ep.load())
value = cast(PluginT, ep.load())
self.register(name, value)
self._active_name = name
self._active = self._plugins[name]
Expand Down Expand Up @@ -204,18 +240,21 @@ def options(self) -> dict[str, Any]:
"""Return the current options dictionary"""
return self._options

def get(self) -> PluginType | Callable[..., Any] | None:
def get(self) -> partial[R] | Plugin[R] | None:
"""Return the currently active plugin."""
if self._options:
if func := self._active:
# NOTE: Fully do not understand this one
# error: Argument 1 to "partial" has incompatible type "PluginType"; expected "Callable[..., Never]"
return partial(func, **self._options) # type: ignore[arg-type]
else:
msg = "Unclear what this meant by passing to curry."
raise TypeError(msg)
else:
return self._active
if (func := self._active) and self.plugin_type(func):
return partial(func, **self._options) if self._options else func
elif self._active is not None:
msg = (
f"{type(self).__name__!r} requires all plugins to be callable objects, "
f"but {type(self._active).__name__!r} is not callable."
)
raise TypeError(msg)
elif TYPE_CHECKING:
# NOTE: The `None` return is implicit, but `mypy` isn't satisfied
# - `ruff` will factor out explicit `None` return
# - `pyright` has no issue
raise NotImplementedError

def __repr__(self) -> str:
return f"{type(self).__name__}(active={self.active!r}, registered={self.names()!r})"
Expand All @@ -228,6 +267,6 @@ def importlib_metadata_get(group):
# also get compatibility with the importlib_metadata package which had a different
# deprecation cycle for 'get'
if hasattr(ep, "select"):
return ep.select(group=group)
return ep.select(group=group) # pyright: ignore
else:
return ep.get(group, [])
2 changes: 1 addition & 1 deletion altair/utils/theme.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
ThemeType = Callable[..., dict]


class ThemeRegistry(PluginRegistry[ThemeType]):
class ThemeRegistry(PluginRegistry[ThemeType, dict]):
pass
3 changes: 2 additions & 1 deletion altair/vegalite/v5/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
data_transformers = DataTransformerRegistry(entry_point_group=ENTRY_POINT_GROUP)
data_transformers.register("default", default_data_transformer)
data_transformers.register("json", to_json)
data_transformers.register("csv", to_csv)
# FIXME: `to_csv` cannot accept all `DataType` https://github.com/vega/altair/issues/3441
data_transformers.register("csv", to_csv) # type: ignore[arg-type]
data_transformers.register("vegafusion", vegafusion_data_transformer)
data_transformers.enable("default")

Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Callable


class TypedCallableRegistry(PluginRegistry[Callable[[int], int]]):
class TypedCallableRegistry(PluginRegistry[Callable[[int], int], int]):
pass


Expand Down