Skip to content

Commit

Permalink
Merge pull request #4950 from Textualize/faster-query
Browse files Browse the repository at this point in the history
Faster query_one
  • Loading branch information
willmcgugan authored Aug 28, 2024
2 parents 75d71f5 + 78bd0f5 commit 58d25fb
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added

- Added `DOMNode.check_consume_key` https://github.com/Textualize/textual/pull/4940
- Added `DOMNode.query_exactly_one` https://github.com/Textualize/textual/pull/4950
- Added `SelectorSet.is_simple` https://github.com/Textualize/textual/pull/4950

### Changed

- KeyPanel will show multiple keys if bound to the same action https://github.com/Textualize/textual/pull/4940
- Breaking change: `DOMNode.query_one` will not `raise TooManyMatches` https://github.com/Textualize/textual/pull/4950

## [0.78.0] - 2024-08-27

Expand Down
1 change: 0 additions & 1 deletion docs/guide/queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ send_button = self.query_one("#send")

This will retrieve a widget with an ID of `send`, if there is exactly one.
If there are no matching widgets, Textual will raise a [NoMatches][textual.css.query.NoMatches] exception.
If there is more than one match, Textual will raise a [TooManyMatches][textual.css.query.TooManyMatches] exception.

You can also add a second parameter for the expected type, which will ensure that you get the type you are expecting.

Expand Down
21 changes: 15 additions & 6 deletions src/textual/_node_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
if TYPE_CHECKING:
from _typeshed import SupportsRichComparison

from .dom import DOMNode
from .widget import Widget


Expand All @@ -24,7 +25,8 @@ class NodeList(Sequence["Widget"]):
Although named a list, widgets may appear only once, making them more like a set.
"""

def __init__(self) -> None:
def __init__(self, parent: DOMNode | None = None) -> None:
self._parent = parent
# The nodes in the list
self._nodes: list[Widget] = []
self._nodes_set: set[Widget] = set()
Expand Down Expand Up @@ -52,6 +54,13 @@ def __len__(self) -> int:
def __contains__(self, widget: object) -> bool:
return widget in self._nodes

def updated(self) -> None:
"""Mark the nodes as having been updated."""
self._updates += 1
node = self._parent
while node is not None and (node := node._parent) is not None:
node._nodes._updates += 1

def _sort(
self,
*,
Expand All @@ -69,7 +78,7 @@ def _sort(
else:
self._nodes.sort(key=key, reverse=reverse)

self._updates += 1
self.updated()

def index(self, widget: Any, start: int = 0, stop: int = sys.maxsize) -> int:
"""Return the index of the given widget.
Expand Down Expand Up @@ -102,7 +111,7 @@ def _append(self, widget: Widget) -> None:
if widget_id is not None:
self._ensure_unique_id(widget_id)
self._nodes_by_id[widget_id] = widget
self._updates += 1
self.updated()

def _insert(self, index: int, widget: Widget) -> None:
"""Insert a Widget.
Expand All @@ -117,7 +126,7 @@ def _insert(self, index: int, widget: Widget) -> None:
if widget_id is not None:
self._ensure_unique_id(widget_id)
self._nodes_by_id[widget_id] = widget
self._updates += 1
self.updated()

def _ensure_unique_id(self, widget_id: str) -> None:
if widget_id in self._nodes_by_id:
Expand All @@ -141,15 +150,15 @@ def _remove(self, widget: Widget) -> None:
widget_id = widget.id
if widget_id in self._nodes_by_id:
del self._nodes_by_id[widget_id]
self._updates += 1
self.updated()

def _clear(self) -> None:
"""Clear the node list."""
if self._nodes:
self._nodes.clear()
self._nodes_set.clear()
self._nodes_by_id.clear()
self._updates += 1
self.updated()

def __iter__(self) -> Iterator[Widget]:
return iter(self._nodes)
Expand Down
9 changes: 9 additions & 0 deletions src/textual/css/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ def __post_init__(self) -> None:
def css(self) -> str:
return RuleSet._selector_to_css(self.selectors)

@property
def is_simple(self) -> bool:
"""Are all the selectors simple (i.e. only dependent on static DOM state)."""
simple_types = {SelectorType.ID, SelectorType.TYPE}
return all(
(selector.type in simple_types and not selector.pseudo_classes)
for selector in self.selectors
)

def __rich_repr__(self) -> rich.repr.Result:
selectors = RuleSet._selector_to_css(self.selectors)
yield selectors
Expand Down
117 changes: 108 additions & 9 deletions src/textual/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@
from ._node_list import NodeList
from ._types import WatchCallbackType
from .binding import Binding, BindingsMap, BindingType
from .cache import LRUCache
from .color import BLACK, WHITE, Color
from .css._error_tools import friendly_list
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
from .css.errors import DeclarationError, StyleValueError
from .css.parse import parse_declarations
from .css.match import match
from .css.parse import parse_declarations, parse_selectors
from .css.query import NoMatches, TooManyMatches
from .css.styles import RenderStyles, Styles
from .css.tokenize import IDENTIFIER
from .message_pump import MessagePump
Expand All @@ -60,7 +63,7 @@
from .worker import Worker, WorkType, ResultType

# Unused & ignored imports are needed for the docs to link to these objects:
from .css.query import NoMatches, TooManyMatches, WrongType # type: ignore # noqa: F401
from .css.query import WrongType # type: ignore # noqa: F401

from typing_extensions import Literal

Expand All @@ -74,6 +77,10 @@
ReactiveType = TypeVar("ReactiveType")


QueryOneCacheKey: TypeAlias = "tuple[int, str, Type[Widget] | None]"
"""The key used to cache query_one results."""


class BadIdentifier(Exception):
"""Exception raised if you supply a `id` attribute or class name in the wrong format."""

Expand Down Expand Up @@ -184,13 +191,14 @@ def __init__(
self._name = name
self._id = None
if id is not None:
self.id = id
check_identifiers("id", id)
self._id = id

_classes = classes.split() if classes else []
check_identifiers("class name", *_classes)
self._classes.update(_classes)

self._nodes: NodeList = NodeList()
self._nodes: NodeList = NodeList(self)
self._css_styles: Styles = Styles(self)
self._inline_styles: Styles = Styles(self)
self.styles: RenderStyles = RenderStyles(
Expand All @@ -213,6 +221,8 @@ def __init__(
dict[str, tuple[MessagePump, Reactive | object]] | None
) = None
self._pruning = False
self._query_one_cache: LRUCache[QueryOneCacheKey, DOMNode] = LRUCache(1024)

super().__init__()

def set_reactive(
Expand Down Expand Up @@ -741,7 +751,7 @@ def id(self, new_id: str) -> str:
ValueError: If the ID has already been set.
"""
check_identifiers("id", new_id)

self._nodes.updated()
if self._id is not None:
raise ValueError(
f"Node 'id' attribute may not be changed once set (current id={self._id!r})"
Expand Down Expand Up @@ -1393,21 +1403,110 @@ def query_one(
Raises:
WrongType: If the wrong type was found.
NoMatches: If no node matches the query.
TooManyMatches: If there is more than one matching node in the query.
Returns:
A widget matching the selector.
"""
_rich_traceback_omit = True
from .css.query import DOMQuery

if isinstance(selector, str):
query_selector = selector
else:
query_selector = selector.__name__
query: DOMQuery[Widget] = DOMQuery(self, filter=query_selector)

return query.only_one() if expect_type is None else query.only_one(expect_type)
selector_set = parse_selectors(query_selector)

if all(selectors.is_simple for selectors in selector_set):
cache_key = (self._nodes._updates, query_selector, expect_type)
cached_result = self._query_one_cache.get(cache_key)
if cached_result is not None:
return cached_result
else:
cache_key = None

for node in walk_depth_first(self, with_root=False):
if not match(selector_set, node):
continue
if expect_type is not None and not isinstance(node, expect_type):
continue
if cache_key is not None:
self._query_one_cache[cache_key] = node
return node

raise NoMatches(f"No nodes match {selector!r} on {self!r}")

if TYPE_CHECKING:

@overload
def query_exactly_one(self, selector: str) -> Widget: ...

@overload
def query_exactly_one(self, selector: type[QueryType]) -> QueryType: ...

@overload
def query_exactly_one(
self, selector: str, expect_type: type[QueryType]
) -> QueryType: ...

def query_exactly_one(
self,
selector: str | type[QueryType],
expect_type: type[QueryType] | None = None,
) -> QueryType | Widget:
"""Get a widget from this widget's children that matches a selector or widget type.
!!! Note
This method is similar to [query_one][textual.dom.DOMNode.query_one].
The only difference is that it will raise `TooManyMatches` if there is more than a single match.
Args:
selector: A selector or widget type.
expect_type: Require the object be of the supplied type, or None for any type.
Raises:
WrongType: If the wrong type was found.
NoMatches: If no node matches the query.
TooManyMatches: If there is more than one matching node in the query (and `exactly_one==True`).
Returns:
A widget matching the selector.
"""
_rich_traceback_omit = True

if isinstance(selector, str):
query_selector = selector
else:
query_selector = selector.__name__

selector_set = parse_selectors(query_selector)

if all(selectors.is_simple for selectors in selector_set):
cache_key = (self._nodes._updates, query_selector, expect_type)
cached_result = self._query_one_cache.get(cache_key)
if cached_result is not None:
return cached_result
else:
cache_key = None

children = walk_depth_first(self, with_root=False)
iter_children = iter(children)
for node in iter_children:
if not match(selector_set, node):
continue
if expect_type is not None and not isinstance(node, expect_type):
continue
for later_node in iter_children:
if match(selector_set, later_node):
if expect_type is not None and not isinstance(node, expect_type):
continue
raise TooManyMatches(
"Call to query_one resulted in more than one matched node"
)
if cache_key is not None:
self._query_one_cache[cache_key] = node
return node

raise NoMatches(f"No nodes match {selector!r} on {self!r}")

def set_styles(self, css: str | None = None, **update_styles: Any) -> Self:
"""Set custom styles on this object.
Expand Down
26 changes: 9 additions & 17 deletions src/textual/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
from .renderables.blank import Blank
from .rlock import RLock
from .strip import Strip
from .walk import walk_depth_first

if TYPE_CHECKING:
from .app import App, ComposeResult
Expand Down Expand Up @@ -807,21 +806,14 @@ def get_widget_by_id(
NoMatches: if no children could be found for this ID.
WrongType: if the wrong type was found.
"""
# We use Widget as a filter_type so that the inferred type of child is Widget.
for child in walk_depth_first(self, filter_type=Widget):
try:
if expect_type is None:
return child.get_child_by_id(id)
else:
return child.get_child_by_id(id, expect_type=expect_type)
except NoMatches:
pass
except WrongType as exc:
raise WrongType(
f"Descendant with id={id!r} is wrong type; expected {expect_type},"
f" got {type(child)}"
) from exc
raise NoMatches(f"No descendant found with id={id!r}")

widget = self.query_one(f"#{id}")
if expect_type is not None and not isinstance(widget, expect_type):
raise WrongType(
f"Descendant with id={id!r} is wrong type; expected {expect_type},"
f" got {type(widget)}"
)
return widget

def get_child_by_type(self, expect_type: type[ExpectType]) -> ExpectType:
"""Get the first immediate child of a given type.
Expand Down Expand Up @@ -958,7 +950,7 @@ def _find_mount_point(self, spot: int | str | "Widget") -> tuple["Widget", int]:
# can be passed to query_one. So let's use that to get a widget to
# work on.
if isinstance(spot, str):
spot = self.query_one(spot, Widget)
spot = self.query_exactly_one(spot, Widget)

# At this point we should have a widget, either because we got given
# one, or because we pulled one out of the query. First off, does it
Expand Down
2 changes: 1 addition & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class App(Widget):
assert app.query_one("#widget1") == widget1
assert app.query_one("#widget1", Widget) == widget1
with pytest.raises(TooManyMatches):
_ = app.query_one(Widget)
_ = app.query_exactly_one(Widget)

assert app.query("Widget.float")[0] == sidebar
assert app.query("Widget.float")[0:2] == [sidebar, tooltip]
Expand Down

0 comments on commit 58d25fb

Please sign in to comment.