Skip to content

Commit

Permalink
Refactor and Fixes: Type Hints, Import Issues, and Code Formatting
Browse files Browse the repository at this point in the history
- Refactored code to address type hint issues across multiple files.
- Fixed import issues.
- Resolved `mypy` and `pytest` errors.
- Formatted with ruff.
- Removal of few unnecessary lambdas.
  • Loading branch information
VasigaranAndAngel committed Nov 27, 2024
1 parent 7b2636e commit cb94fd5
Show file tree
Hide file tree
Showing 26 changed files with 553 additions and 707 deletions.
6 changes: 3 additions & 3 deletions tagstudio/src/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ class LibraryPrefs(DefaultEnum):
"""Library preferences with default value accessible via .default property."""

IS_EXCLUDE_LIST = True
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
PAGE_SIZE: int = 500
DB_VERSION: int = 2
EXTENSION_LIST = [".json", ".xmp", ".aae"]
PAGE_SIZE = 500
DB_VERSION = 2
62 changes: 23 additions & 39 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,25 @@
import re
import shutil
import unicodedata
from collections.abc import Iterator
from dataclasses import dataclass
from datetime import UTC, datetime
from os import makedirs
from pathlib import Path
from typing import Any, Iterator, Type
from typing import Any, Type
from uuid import uuid4

import structlog
from sqlalchemy import (
URL,
Engine,
and_,
create_engine,
delete,
exists,
func,
or_,
select,
update,
)
from sqlalchemy import URL, Engine, and_, create_engine, delete, exists, func, or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import (
Session,
aliased,
contains_eager,
make_transient,
selectinload,
)

from ...constants import (
BACKUP_FOLDER_NAME,
TAG_ARCHIVED,
TAG_FAVORITE,
TS_FOLDER_NAME,
)
from sqlalchemy.orm import Session, aliased, contains_eager, make_transient, selectinload

from ...constants import BACKUP_FOLDER_NAME, TAG_ARCHIVED, TAG_FAVORITE, TS_FOLDER_NAME
from ...enums import LibraryPrefs
from ...media_types import MediaCategories
from .db import make_tables
from .enums import FieldTypeEnum, FilterState, TagColor
from .fields import (
BaseField,
DatetimeField,
TagBoxField,
TextField,
_FieldID,
)
from .fields import BaseField, DatetimeField, TagBoxField, TextField, _FieldID
from .joins import TagField, TagSubtag
from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType

Expand Down Expand Up @@ -210,6 +183,7 @@ def open_library(self, library_dir: Path, storage_path: str | None = None) -> Li
db_version = session.scalar(
select(Preferences).where(Preferences.key == LibraryPrefs.DB_VERSION.name)
)
assert db_version is not None
# if the db version is different, we cant proceed
if db_version.value != LibraryPrefs.DB_VERSION.default:
logger.error(
Expand Down Expand Up @@ -310,7 +284,7 @@ def get_entry(self, entry_id: int) -> Entry | None:
@property
def entries_count(self) -> int:
with Session(self.engine) as session:
return session.scalar(select(func.count(Entry.id)))
return session.scalar(select(func.count(Entry.id))) or 0

def get_entries(self, with_joins: bool = False) -> Iterator[Entry]:
"""Load entries without joins."""
Expand Down Expand Up @@ -484,7 +458,7 @@ def search_library(
)

query_count = select(func.count()).select_from(statement.alias("entries"))
count_all: int = session.execute(query_count).scalar()
count_all: int = session.execute(query_count).scalar() or 0

statement = statement.limit(search.limit).offset(search.offset)

Expand Down Expand Up @@ -683,6 +657,8 @@ def field_types(self) -> dict[str, ValueType]:
def get_value_type(self, field_key: str) -> ValueType:
with Session(self.engine) as session:
field = session.scalar(select(ValueType).where(ValueType.key == field_key))
if field is None:
raise ValueError(f"No field found with key {field_key}.")
session.expunge(field)
return field

Expand All @@ -709,8 +685,10 @@ def add_entry_field_type(

if not field:
if isinstance(field_id, _FieldID):
field_id = field_id.name
field = self.get_value_type(field_id)
_field_id = field_id.name
elif isinstance(field_id, str):
_field_id = field_id
field = self.get_value_type(_field_id)

field_model: TextField | DatetimeField | TagBoxField
if field.type in (FieldTypeEnum.TEXT_LINE, FieldTypeEnum.TEXT_BOX):
Expand Down Expand Up @@ -860,6 +838,7 @@ def get_tag(self, tag_id: int) -> Tag:
with Session(self.engine) as session:
tags_query = select(Tag).options(selectinload(Tag.subtags), selectinload(Tag.aliases))
tag = session.scalar(tags_query.where(Tag.id == tag_id))
assert tag is not None

session.expunge(tag)
for subtag in tag.subtags:
Expand All @@ -875,6 +854,7 @@ def get_alias(self, tag_id: int, alias_id: int) -> TagAlias:
alias_query = select(TagAlias).where(TagAlias.id == alias_id, TagAlias.tag_id == tag_id)
alias = session.scalar(alias_query.where(TagAlias.id == alias_id))

assert alias is not None
return alias

def add_subtag(self, base_id: int, new_tag_id: int) -> bool:
Expand Down Expand Up @@ -957,13 +937,17 @@ def update_subtags(self, tag, subtag_ids, session):
def prefs(self, key: LibraryPrefs) -> Any:
# load given item from Preferences table
with Session(self.engine) as session:
return session.scalar(select(Preferences).where(Preferences.key == key.name)).value
pref = session.scalar(select(Preferences).where(Preferences.key == key.name))
assert pref is not None
return pref.value

def set_prefs(self, key: LibraryPrefs, value: Any) -> None:
# set given item in Preferences table
with Session(self.engine) as session:
# load existing preference and update value
pref = session.scalar(select(Preferences).where(Preferences.key == key.name))
if pref is None:
raise KeyError(f"Preference {key} does not exist")
pref.value = value
session.add(pref)
session.commit()
Expand Down
10 changes: 4 additions & 6 deletions tagstudio/src/qt/flowlayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""PySide6 port of the widgets/layouts/flowlayout example from Qt v6.x."""

from PySide6.QtCore import QMargins, QPoint, QRect, QSize, Qt
from PySide6.QtWidgets import QLayout, QSizePolicy, QWidget
from PySide6.QtWidgets import QLayout, QLayoutItem, QSizePolicy, QWidget


class FlowWidget(QWidget):
Expand All @@ -21,7 +21,7 @@ def __init__(self, parent=None):
if parent is not None:
self.setContentsMargins(QMargins(0, 0, 0, 0))

self._item_list = []
self._item_list: list[QLayoutItem] = []
self.grid_efficiency = False

def __del__(self):
Expand Down Expand Up @@ -88,8 +88,6 @@ def _do_layout(self, rect: QRect, test_only: bool) -> float:
y = rect.y()
line_height = 0
spacing = self.spacing()
layout_spacing_x = None
layout_spacing_y = None

if self.grid_efficiency and self._item_list:
item = self._item_list[0]
Expand All @@ -107,10 +105,10 @@ def _do_layout(self, rect: QRect, test_only: bool) -> float:

for item in self._item_list:
skip_count = 0
if issubclass(type(item.widget()), FlowWidget) and item.widget().ignore_size:
if issubclass(type(item.widget()), FlowWidget) and item.widget().ignore_size: # type: ignore
skip_count += 1

if (issubclass(type(item.widget()), FlowWidget) and not item.widget().ignore_size) or (
if (issubclass(type(item.widget()), FlowWidget) and not item.widget().ignore_size) or ( # type: ignore
not issubclass(type(item.widget()), FlowWidget)
):
if not self.grid_efficiency:
Expand Down
2 changes: 1 addition & 1 deletion tagstudio/src/qt/helpers/color_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def theme_fg_overlay(image: Image.Image, use_alpha: bool = True) -> Image.Image:
return _apply_overlay(image, im)


def gradient_overlay(image: Image.Image, gradient=list[str]) -> Image.Image:
def gradient_overlay(image: Image.Image, gradient: list[str]) -> Image.Image:
"""Overlay a color gradient onto an image.
Args:
Expand Down
27 changes: 13 additions & 14 deletions tagstudio/src/qt/helpers/file_opener.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

import structlog
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QLabel
from PySide6.QtGui import QMouseEvent
from PySide6.QtWidgets import QLabel, QWidget

logger = structlog.get_logger(__name__)

Expand All @@ -34,11 +35,11 @@ def open_file(path: str | Path, file_manager: bool = False):
normpath = Path(path).resolve().as_posix()
if file_manager:
command_name = "explorer"
command_args = '/select,"' + normpath + '"'
_command_args = '/select,"' + normpath + '"'
# For some reason, if the args are passed in a list, this will error when the
# path has spaces, even while surrounded in double quotes.
subprocess.Popen(
command_name + command_args,
command_name + _command_args,
shell=True,
close_fds=True,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
Expand Down Expand Up @@ -92,15 +93,15 @@ def __init__(self, filepath: str | Path):
"""Initialize the FileOpenerHelper.
Args:
filepath (str): The path to the file to open.
filepath (str): The path to the file to open.
"""
self.filepath = str(filepath)

def set_filepath(self, filepath: str | Path):
"""Set the filepath to open.
Args:
filepath (str): The path to the file to open.
filepath (str): The path to the file to open.
"""
self.filepath = str(filepath)

Expand All @@ -114,34 +115,32 @@ def open_explorer(self):


class FileOpenerLabel(QLabel):
def __init__(self, text, parent=None):
def __init__(self, text: str, parent: QWidget | None = None):
"""Initialize the FileOpenerLabel.
Args:
text (str): The text to display.
parent (QWidget, optional): The parent widget. Defaults to None.
text (str): The text to display.
parent (QWidget, optional): The parent widget. Defaults to None.
"""
super().__init__(text, parent)

def set_file_path(self, filepath):
def set_file_path(self, filepath: str | Path):
"""Set the filepath to open.
Args:
filepath (str): The path to the file to open.
filepath (str): The path to the file to open.
"""
self.filepath = filepath

def mousePressEvent(self, event): # noqa: N802
def mousePressEvent(self, event: QMouseEvent): # noqa: N802
"""Handle mouse press events.
On a left click, open the file in the default file explorer.
On a right click, show a context menu.
Args:
event (QMouseEvent): The mouse press event.
event (QMouseEvent): The mouse press event.
"""
super().mousePressEvent(event)

if event.button() == Qt.MouseButton.LeftButton:
opener = FileOpenerHelper(self.filepath)
opener.open_explorer()
Expand Down
4 changes: 2 additions & 2 deletions tagstudio/src/qt/helpers/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def four_corner_gradient(


def linear_gradient(
size=tuple[int, int],
colors=list[str],
size: tuple[int, int],
colors: list[str],
interpolation: Image.Resampling = Image.Resampling.BICUBIC,
) -> Image.Image:
seed: Image.Image = Image.new(mode="RGBA", size=(len(colors), 1), color="#000000")
Expand Down
9 changes: 4 additions & 5 deletions tagstudio/src/qt/helpers/text_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from PIL import Image, ImageDraw, ImageFont


def wrap_line( # type: ignore
def wrap_line(
text: str,
font: ImageFont.ImageFont,
width: int = 256,
draw: ImageDraw.ImageDraw = None,
draw: ImageDraw.ImageDraw | None = None,
) -> int:
"""Take in a single text line and return the index it should be broken up at.
Expand All @@ -26,15 +26,14 @@ def wrap_line( # type: ignore
):
if draw.textlength(text[:i], font=font) < width:
return i
else:
return -1
return -1


def wrap_full_text(
text: str,
font: ImageFont.ImageFont,
width: int = 256,
draw: ImageDraw.ImageDraw = None,
draw: ImageDraw.ImageDraw | None = None,
) -> str:
"""Break up a string to fit the canvas given a kerning value, font size, etc."""
lines = []
Expand Down
Loading

0 comments on commit cb94fd5

Please sign in to comment.