Skip to content

Commit c159638

Browse files
feat: make search results more ergonomic (#498)
1 parent a8fdae8 commit c159638

File tree

7 files changed

+87
-57
lines changed

7 files changed

+87
-57
lines changed

tagstudio/src/core/library/alchemy/library.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dataclasses import dataclass
12
from datetime import datetime, UTC
23
import shutil
34
from os import makedirs
@@ -87,6 +88,33 @@ def get_default_tags() -> tuple[Tag, ...]:
8788
return archive_tag, favorite_tag
8889

8990

91+
@dataclass(frozen=True)
92+
class SearchResult:
93+
"""Wrapper for search results.
94+
95+
:param total_count: total number of items for given query, might be different than len(items)
96+
:param items: items for current page (size matches filter.page_size)
97+
"""
98+
99+
total_count: int
100+
items: list[Entry]
101+
102+
def __bool__(self) -> bool:
103+
"""Boolean evaluation for the wrapper.
104+
105+
:return: True if there are items in the result.
106+
"""
107+
return self.total_count > 0
108+
109+
def __len__(self) -> int:
110+
"""Return the total number of items in the result."""
111+
return len(self.items)
112+
113+
def __getitem__(self, index: int) -> Entry:
114+
"""Allow to access items via index directly on the wrapper."""
115+
return self.items[index]
116+
117+
90118
class Library:
91119
"""Class for the Library object, and all CRUD operations made upon it."""
92120

@@ -325,7 +353,7 @@ def has_path_entry(self, path: Path) -> bool:
325353
def search_library(
326354
self,
327355
search: FilterState,
328-
) -> tuple[int, list[Entry]]:
356+
) -> SearchResult:
329357
"""Filter library by search query.
330358
331359
:return: number of entries matching the query and one page of results.
@@ -401,11 +429,14 @@ def search_library(
401429
),
402430
)
403431

404-
entries_ = list(session.scalars(statement).unique())
432+
res = SearchResult(
433+
total_count=count_all,
434+
items=list(session.scalars(statement).unique()),
435+
)
405436

406437
session.expunge_all()
407438

408-
return count_all, entries_
439+
return res
409440

410441
def search_tags(
411442
self,

tagstudio/src/core/utils/dupe_files.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ def refresh_dupe_files(self, results_filepath: str | Path):
5050
# The file is not in the library directory
5151
continue
5252

53-
_, entries = self.library.search_library(
53+
results = self.library.search_library(
5454
FilterState(path=path_relative),
5555
)
5656

57-
if not entries:
57+
if not results:
5858
# file not in library
5959
continue
6060

61-
files.append(entries[0])
61+
files.append(results[0])
6262

6363
if not len(files) > 1:
6464
# only one file in the group, nothing to do

tagstudio/src/qt/ts_qt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,26 +1009,26 @@ def filter_items(self, filter: FilterState | None = None) -> None:
10091009
self.main_window.statusbar.repaint()
10101010
start_time = time.time()
10111011

1012-
query_count, page_items = self.lib.search_library(self.filter)
1012+
results = self.lib.search_library(self.filter)
10131013

1014-
logger.info("items to render", count=len(page_items))
1014+
logger.info("items to render", count=len(results))
10151015

10161016
end_time = time.time()
10171017
if self.filter.summary:
10181018
self.main_window.statusbar.showMessage(
1019-
f'{query_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
1019+
f'{results.total_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
10201020
)
10211021
else:
10221022
self.main_window.statusbar.showMessage(
1023-
f"{query_count} Results ({format_timespan(end_time - start_time)})"
1023+
f"{results.total_count} Results ({format_timespan(end_time - start_time)})"
10241024
)
10251025

10261026
# update page content
1027-
self.frame_content = list(page_items)
1027+
self.frame_content = results.items
10281028
self.update_thumbs()
10291029

10301030
# update pagination
1031-
self.pages_count = math.ceil(query_count / self.filter.page_size)
1031+
self.pages_count = math.ceil(results.total_count / self.filter.page_size)
10321032
self.main_window.pagination.update_buttons(
10331033
self.pages_count, self.filter.page_index, emit=False
10341034
)

tagstudio/src/qt/widgets/item_thumb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def on_badge_check(self, badge_type: BadgeType):
487487
# update the entry
488488
self.driver.frame_content[idx] = self.lib.search_library(
489489
FilterState(id=entry.id)
490-
)[1][0]
490+
).items[0]
491491

492492
self.driver.update_badges(update_items)
493493

tagstudio/src/qt/widgets/preview_panel.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def update_selected_entry(driver: "QtDriver"):
6363
for grid_idx in driver.selected:
6464
entry = driver.frame_content[grid_idx]
6565
# reload entry
66-
_, entries = driver.lib.search_library(FilterState(id=entry.id))
66+
results = driver.lib.search_library(FilterState(id=entry.id))
6767
logger.info(
68-
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
68+
"found item", entries=len(results), grid_idx=grid_idx, lookup_id=entry.id
6969
)
70-
assert entries, f"Entry not found: {entry.id}"
71-
driver.frame_content[grid_idx] = entries[0]
70+
assert results, f"Entry not found: {entry.id}"
71+
driver.frame_content[grid_idx] = next(results)
7272

7373

7474
class PreviewPanel(QWidget):
@@ -499,11 +499,14 @@ def update_widgets(self) -> bool:
499499
# TODO - Entry reload is maybe not necessary
500500
for grid_idx in self.driver.selected:
501501
entry = self.driver.frame_content[grid_idx]
502-
_, entries = self.lib.search_library(FilterState(id=entry.id))
502+
results = self.lib.search_library(FilterState(id=entry.id))
503503
logger.info(
504-
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
504+
"found item",
505+
entries=len(results.items),
506+
grid_idx=grid_idx,
507+
lookup_id=entry.id,
505508
)
506-
self.driver.frame_content[grid_idx] = entries[0]
509+
self.driver.frame_content[grid_idx] = results[0]
507510

508511
if len(self.driver.selected) == 1:
509512
# 1 Selected Entry

tagstudio/tests/macros/test_missing_files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ def test_refresh_missing_files(library: Library):
2727
assert list(registry.fix_missing_files()) == [1, 2]
2828

2929
# `bar.md` should be relinked to new correct path
30-
_, entries = library.search_library(FilterState(path="bar.md"))
31-
assert entries[0].path == pathlib.Path("bar.md")
30+
results = library.search_library(FilterState(path="bar.md"))
31+
assert results[0].path == pathlib.Path("bar.md")

tagstudio/tests/test_library.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,16 @@ def test_library_search(library, generate_tag, entry_full):
5757
assert library.entries_count == 2
5858
tag = list(entry_full.tags)[0]
5959

60-
query_count, items = library.search_library(
60+
results = library.search_library(
6161
FilterState(
6262
tag=tag.name,
6363
),
6464
)
6565

66-
assert query_count == 1
67-
assert len(items) == 1
66+
assert results.total_count == 1
67+
assert len(results) == 1
6868

69-
entry = items[0]
69+
entry = results[0]
7070
assert {x.name for x in entry.tags} == {
7171
"foo",
7272
}
@@ -94,9 +94,9 @@ def test_tag_search(library):
9494

9595
def test_get_entry(library, entry_min):
9696
assert entry_min.id
97-
cnt, entries = library.search_library(FilterState(id=entry_min.id))
98-
assert len(entries) == cnt == 1
99-
assert entries[0].tags
97+
results = library.search_library(FilterState(id=entry_min.id))
98+
assert len(results) == results.total_count == 1
99+
assert results[0].tags
100100

101101

102102
def test_entries_count(library):
@@ -105,14 +105,14 @@ def test_entries_count(library):
105105
for x in range(10)
106106
]
107107
library.add_entries(entries)
108-
matches, page = library.search_library(
108+
results = library.search_library(
109109
FilterState(
110110
page_size=5,
111111
)
112112
)
113113

114-
assert matches == 12
115-
assert len(page) == 5
114+
assert results.total_count == 12
115+
assert len(results) == 5
116116

117117

118118
def test_add_field_to_entry(library):
@@ -146,8 +146,8 @@ def test_add_field_tag(library, entry_full, generate_tag):
146146
library.add_field_tag(entry_full, tag, tag_field.type_key)
147147

148148
# Then
149-
_, entries = library.search_library(FilterState(id=entry_full.id))
150-
tag_field = entries[0].tag_box_fields[0]
149+
results = library.search_library(FilterState(id=entry_full.id))
150+
tag_field = results[0].tag_box_fields[0]
151151
assert [x.name for x in tag_field.tags if x.name == tag_name]
152152

153153

@@ -179,15 +179,15 @@ def test_search_filter_extensions(library, is_exclude):
179179
library.set_prefs(LibraryPrefs.EXTENSION_LIST, ["md"])
180180

181181
# When
182-
query_count, items = library.search_library(
182+
results = library.search_library(
183183
FilterState(),
184184
)
185185

186186
# Then
187-
assert query_count == 1
188-
assert len(items) == 1
187+
assert results.total_count == 1
188+
assert len(results) == 1
189189

190-
entry = items[0]
190+
entry = results[0]
191191
assert (entry.path.suffix == ".txt") == is_exclude
192192

193193

@@ -200,15 +200,15 @@ def test_search_library_case_insensitive(library):
200200
tag = list(entry.tags)[0]
201201

202202
# When
203-
query_count, items = library.search_library(
203+
results = library.search_library(
204204
FilterState(tag=tag.name.upper()),
205205
)
206206

207207
# Then
208-
assert query_count == 1
209-
assert len(items) == 1
208+
assert results.total_count == 1
209+
assert len(results) == 1
210210

211-
assert items[0].id == entry.id
211+
assert results[0].id == entry.id
212212

213213

214214
def test_preferences(library):
@@ -231,11 +231,11 @@ def test_save_windows_path(library, generate_tag):
231231
# library.add_tag(tag)
232232
library.add_field_tag(entry, tag, create_field=True)
233233

234-
_, found = library.search_library(FilterState(tag=tag_name))
235-
assert found
234+
results = library.search_library(FilterState(tag=tag_name))
235+
assert results
236236

237237
# path should be saved in posix format
238-
assert str(found[0].path) == "foo/bar.txt"
238+
assert str(results[0].path) == "foo/bar.txt"
239239

240240

241241
def test_remove_entry_field(library, entry_full):
@@ -312,13 +312,13 @@ def test_mirror_entry_fields(library, entry_full):
312312

313313
entry_id = library.add_entries([target_entry])[0]
314314

315-
_, entries = library.search_library(FilterState(id=entry_id))
316-
new_entry = entries[0]
315+
results = library.search_library(FilterState(id=entry_id))
316+
new_entry = results[0]
317317

318318
library.mirror_entry_fields(new_entry, entry_full)
319319

320-
_, entries = library.search_library(FilterState(id=entry_id))
321-
entry = entries[0]
320+
results = library.search_library(FilterState(id=entry_id))
321+
entry = results[0]
322322

323323
assert len(entry.fields) == 4
324324
assert {x.type_key for x in entry.fields} == {
@@ -350,13 +350,11 @@ def test_remove_tag_from_field(library, entry_full):
350350
],
351351
)
352352
def test_search_file_name(library, query_name, has_result):
353-
res_count, items = library.search_library(
353+
results = library.search_library(
354354
FilterState(name=query_name),
355355
)
356356

357-
assert (
358-
res_count == has_result
359-
), f"mismatch with query: {query_name}, result: {res_count}"
357+
assert results.total_count == has_result
360358

361359

362360
@pytest.mark.parametrize(
@@ -369,13 +367,11 @@ def test_search_file_name(library, query_name, has_result):
369367
],
370368
)
371369
def test_search_entry_id(library, query_name, has_result):
372-
res_count, items = library.search_library(
370+
results = library.search_library(
373371
FilterState(id=query_name),
374372
)
375373

376-
assert (
377-
res_count == has_result
378-
), f"mismatch with query: {query_name}, result: {res_count}"
374+
assert results.total_count == has_result
379375

380376

381377
def test_update_field_order(library, entry_full):

0 commit comments

Comments
 (0)