diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b0065e7..fde8856 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -37,3 +37,19 @@ jobs: cache-dependency-path: requirements*/*.txt - run: pip install tox - run: tox run -e ${{ matrix.tox || format('py{0}', matrix.python) }} + typing: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@0ad4b8fadaa221de15dcec353f45205ec38ea70b # v4.1.4 + - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 + with: + python-version: '3.x' + cache: pip + cache-dependency-path: requirements*/*.txt + - name: cache mypy + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ./.mypy_cache + key: mypy|${{ hashFiles('pyproject.toml') }} + - run: pip install tox + - run: tox run -e typing diff --git a/pyproject.toml b/pyproject.toml index 9adbe54..ab7ff0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ classifiers = [ "Framework :: Flask", "License :: OSI Approved :: BSD License", "Programming Language :: Python", + "Typing :: Typed", ] requires-python = ">=3.8" dependencies = [ @@ -50,6 +51,12 @@ show_error_codes = true pretty = true strict = true +[[tool.mypy.overrides]] +module = [ + "sqlparse.*" +] +ignore_missing_imports = true + [tool.pyright] pythonVersion = "3.8" include = ["src/flask_debugtoolbar", "tests"] diff --git a/requirements/dev.txt b/requirements/dev.txt index a46d9e8..29842cc 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -15,6 +15,7 @@ babel==2.14.0 blinker==1.8.1 # via # -r tests.txt + # -r typing.txt # flask cachetools==5.3.3 # via tox @@ -33,6 +34,7 @@ charset-normalizer==3.3.2 click==8.1.7 # via # -r tests.txt + # -r typing.txt # flask colorama==0.4.6 # via tox @@ -54,9 +56,12 @@ filelock==3.14.0 flask==3.0.3 # via # -r tests.txt + # -r typing.txt # flask-sqlalchemy flask-sqlalchemy==3.1.1 - # via -r tests.txt + # via + # -r tests.txt + # -r typing.txt identify==2.5.36 # via pre-commit idna==3.7 @@ -71,6 +76,7 @@ importlib-metadata==7.1.0 # via # -r docs.txt # -r tests.txt + # -r typing.txt # flask # sphinx iniconfig==2.0.0 @@ -81,17 +87,20 @@ iniconfig==2.0.0 itsdangerous==2.2.0 # via # -r tests.txt + # -r typing.txt # flask jinja2==3.1.3 # via # -r docs.txt # -r tests.txt + # -r typing.txt # flask # sphinx markupsafe==2.1.5 # via # -r docs.txt # -r tests.txt + # -r typing.txt # jinja2 # werkzeug mypy==1.10.0 @@ -190,6 +199,7 @@ sphinxcontrib-serializinghtml==1.1.5 sqlalchemy==2.0.29 # via # -r tests.txt + # -r typing.txt # flask-sqlalchemy tomli==2.0.1 # via @@ -201,6 +211,16 @@ tomli==2.0.1 # tox tox==4.15.0 # via -r dev.in +types-docutils==0.21.0.20240423 + # via + # -r typing.txt + # types-pygments +types-pygments==2.17.0.20240310 + # via -r typing.txt +types-setuptools==69.5.0.20240423 + # via + # -r typing.txt + # types-pygments typing-extensions==4.11.0 # via # -r tests.txt @@ -218,11 +238,13 @@ virtualenv==20.26.1 werkzeug==3.0.2 # via # -r tests.txt + # -r typing.txt # flask zipp==3.18.1 # via # -r docs.txt # -r tests.txt + # -r typing.txt # importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements/typing.in b/requirements/typing.in index 8be59c5..95d459e 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -1,3 +1,5 @@ mypy pyright pytest +types-pygments +flask-sqlalchemy diff --git a/requirements/typing.txt b/requirements/typing.txt index 3514478..59388cb 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -4,10 +4,28 @@ # # pip-compile typing.in # +blinker==1.8.1 + # via flask +click==8.1.7 + # via flask exceptiongroup==1.2.1 # via pytest +flask==3.0.3 + # via flask-sqlalchemy +flask-sqlalchemy==3.1.1 + # via -r typing.in +importlib-metadata==7.1.0 + # via flask iniconfig==2.0.0 # via pytest +itsdangerous==2.2.0 + # via flask +jinja2==3.1.3 + # via flask +markupsafe==2.1.5 + # via + # jinja2 + # werkzeug mypy==1.10.0 # via -r typing.in mypy-extensions==1.0.0 @@ -22,12 +40,26 @@ pyright==1.1.360 # via -r typing.in pytest==8.2.0 # via -r typing.in +sqlalchemy==2.0.29 + # via flask-sqlalchemy tomli==2.0.1 # via # mypy # pytest +types-docutils==0.21.0.20240423 + # via types-pygments +types-pygments==2.17.0.20240310 + # via -r typing.in +types-setuptools==69.5.0.20240423 + # via types-pygments typing-extensions==4.11.0 - # via mypy + # via + # mypy + # sqlalchemy +werkzeug==3.0.2 + # via flask +zipp==3.18.1 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/src/flask_debugtoolbar/__init__.py b/src/flask_debugtoolbar/__init__.py index 4151ab6..dbd0520 100644 --- a/src/flask_debugtoolbar/__init__.py +++ b/src/flask_debugtoolbar/__init__.py @@ -1,19 +1,26 @@ -import contextvars +from __future__ import annotations + +import collections.abc as c import importlib.metadata import os +import typing as t import urllib.parse import warnings +from contextvars import ContextVar from flask import Blueprint from flask import current_app +from flask import Flask from flask import g from flask import request from flask import send_from_directory from flask import url_for from flask.globals import request_ctx -from jinja2 import __version__ as __jinja_version__ from jinja2 import Environment from jinja2 import PackageLoader +from werkzeug import Request +from werkzeug import Response +from werkzeug.routing import Rule from .toolbar import DebugToolbar from .utils import decode_text @@ -21,11 +28,12 @@ from .utils import gzip_decompress __version__ = importlib.metadata.version("flask-debugtoolbar") +_jinja_version = importlib.metadata.version("jinja2") -module = Blueprint("debugtoolbar", __name__) +module: Blueprint = Blueprint("debugtoolbar", __name__) -def replace_insensitive(string, target, replacement): +def replace_insensitive(string: str, target: str, replacement: str) -> str: """Similar to string.replace() but is case insensitive Code borrowed from: http://forums.devshed.com/python-programming-11/case-insensitive-string-replace-490921.html @@ -39,7 +47,7 @@ def replace_insensitive(string, target, replacement): return string -def _printable(value): +def _printable(value: object) -> str: try: return decode_text(repr(value)) except Exception as e: @@ -52,19 +60,21 @@ class DebugToolbarExtension: _toolbar_codes = [200, 201, 400, 401, 403, 404, 405, 500, 501, 502, 503, 504] _redirect_codes = [301, 302, 303, 304] - def __init__(self, app=None): + def __init__(self, app: Flask | None = None) -> None: self.app = app # Support threads running `flask.copy_current_request_context` without # poping toolbar during `teardown_request` - self.debug_toolbars_var = contextvars.ContextVar("debug_toolbars") + self.debug_toolbars_var: ContextVar[dict[Request, DebugToolbar]] = ContextVar( + "debug_toolbars" + ) jinja_extensions = ["jinja2.ext.i18n"] - if __jinja_version__[0] == "2": + if _jinja_version[0] == "2": jinja_extensions.append("jinja2.ext.with_") # Configure jinja for the internal templates and add url rules # for static data - self.jinja_env = Environment( + self.jinja_env: Environment = Environment( autoescape=True, extensions=jinja_extensions, loader=PackageLoader(__name__, "templates"), @@ -76,7 +86,7 @@ def __init__(self, app=None): if app is not None: self.init_app(app) - def init_app(self, app): + def init_app(self, app: Flask) -> None: for k, v in self._default_config(app).items(): app.config.setdefault(k, v) @@ -96,7 +106,7 @@ def init_app(self, app): app.teardown_request(self.teardown_request) # Monkey-patch the Flask.dispatch_request method - app.dispatch_request = self.dispatch_request + app.dispatch_request = self.dispatch_request # type: ignore[method-assign] app.add_url_rule( "/_debug_toolbar/static/", @@ -106,7 +116,7 @@ def init_app(self, app): app.register_blueprint(module, url_prefix="/_debug_toolbar/views") - def _default_config(self, app): + def _default_config(self, app: Flask) -> dict[str, t.Any]: return { "DEBUG_TB_ENABLED": app.debug, "DEBUG_TB_HOSTS": (), @@ -127,30 +137,32 @@ def _default_config(self, app): "SQLALCHEMY_RECORD_QUERIES": app.debug, } - def dispatch_request(self): - """Modified version of Flask.dispatch_request to call process_view.""" + def dispatch_request(self) -> t.Any: + """Modified version of ``Flask.dispatch_request`` to call + :meth:`process_view`. + """ + # self references this extension, use current_app to call app methods. + app = current_app._get_current_object() # type: ignore[attr-defined] req = request_ctx.request - app = current_app if req.routing_exception is not None: app.raise_routing_exception(req) - rule = req.url_rule + rule: Rule = req.url_rule # type: ignore[assignment] - # if we provide automatic options for this URL and the - # request came with the OPTIONS method, reply automatically if ( getattr(rule, "provide_automatic_options", False) and req.method == "OPTIONS" ): return app.make_default_options_response() - # otherwise dispatch to the handler for that endpoint view_func = app.view_functions[rule.endpoint] - view_func = self.process_view(app, view_func, req.view_args) - return view_func(**req.view_args) + view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment] + # allow each toolbar to process the view and args + view_func = self.process_view(app, view_func, view_args) + return view_func(**view_args) - def _show_toolbar(self): + def _show_toolbar(self) -> bool: """Return a boolean to indicate if we need to show the toolbar.""" if request.blueprint == "debugtoolbar": return False @@ -162,17 +174,17 @@ def _show_toolbar(self): return True - def send_static_file(self, filename): + def send_static_file(self, filename: str) -> Response: """Send a static file from the flask-debugtoolbar static directory.""" return send_from_directory(self._static_dir, filename) - def process_request(self): + def process_request(self) -> None: g.debug_toolbar = self if not self._show_toolbar(): return - real_request = request._get_current_object() + real_request = request._get_current_object() # type: ignore[attr-defined] self.debug_toolbars_var.set({}) self.debug_toolbars_var.get()[real_request] = DebugToolbar( real_request, self.jinja_env @@ -181,11 +193,16 @@ def process_request(self): for panel in self.debug_toolbars_var.get()[real_request].panels: panel.process_request(real_request) - def process_view(self, app, view_func, view_kwargs): + def process_view( + self, + app: Flask, + view_func: c.Callable[..., t.Any], + view_kwargs: dict[str, t.Any], + ) -> c.Callable[..., t.Any]: """This method is called just before the flask view is called. This is done by the dispatch_request method. """ - real_request = request._get_current_object() + real_request = request._get_current_object() # type: ignore[attr-defined] try: toolbar = self.debug_toolbars_var.get({})[real_request] @@ -200,8 +217,8 @@ def process_view(self, app, view_func, view_kwargs): return view_func - def process_response(self, response): - real_request = request._get_current_object() + def process_response(self, response: Response) -> Response: + real_request = request._get_current_object() # type: ignore[attr-defined] if real_request not in self.debug_toolbars_var.get({}): return response @@ -219,7 +236,7 @@ def process_response(self, response): {"redirect_to": redirect_to, "redirect_code": redirect_code}, ) response.content_length = len(content) - response.location = None + del response.location response.response = [content] response.status_code = 200 @@ -263,20 +280,21 @@ def process_response(self, response): toolbar_html = toolbar.render_toolbar() content = "".join((before, toolbar_html, after)) - content = content.encode("utf-8") + content_bytes = content.encode("utf-8") if content_encoding and "gzip" in content_encoding: - content = gzip_compress(content) + content_bytes = gzip_compress(content_bytes) - response.response = [content] - response.content_length = len(content) + response.response = [content_bytes] + response.content_length = len(content_bytes) return response - def teardown_request(self, exc): + def teardown_request(self, exc: BaseException | None) -> None: # debug_toolbars_var won't be set under `flask.copy_current_request_context` - self.debug_toolbars_var.get({}).pop(request._get_current_object(), None) + real_request = request._get_current_object() # type: ignore[attr-defined] + self.debug_toolbars_var.get({}).pop(real_request, None) - def render(self, template_name, context): + def render(self, template_name: str, context: dict[str, t.Any]) -> str: template = self.jinja_env.get_template(template_name) return template.render(**context) diff --git a/src/flask_debugtoolbar/panels/__init__.py b/src/flask_debugtoolbar/panels/__init__.py index 3d34966..9a6486a 100644 --- a/src/flask_debugtoolbar/panels/__init__.py +++ b/src/flask_debugtoolbar/panels/__init__.py @@ -1,7 +1,18 @@ +from __future__ import annotations + +import collections.abc as c +import typing as t + +from flask import Flask +from jinja2 import Environment +from werkzeug import Request +from werkzeug import Response + + class DebugPanel: """Base class for debug panels.""" - # name = Base + name: str # If content returns something, set to true in subclass has_content = False @@ -11,10 +22,12 @@ class DebugPanel: # We'll maintain a local context instance so we can expose our template # context variables to panels which need them: - context = {} + context: dict[str, t.Any] = {} # Panel methods - def __init__(self, jinja_env, context=None): + def __init__( + self, jinja_env: Environment, context: dict[str, t.Any] | None = None + ) -> None: if context is not None: self.context.update(context) @@ -23,7 +36,7 @@ def __init__(self, jinja_env, context=None): self.is_active = False @classmethod - def init_app(cls, app): + def init_app(cls, app: Flask) -> None: """Method that can be overridden by child classes. Can be used for setting up additional URL-rules/routes. @@ -45,37 +58,42 @@ def serve_generated_image(cls, app): """ pass - def render(self, template_name, context): + def render(self, template_name: str, context: dict[str, t.Any]) -> str: template = self.jinja_env.get_template(template_name) return template.render(**context) - def dom_id(self): + def dom_id(self) -> str: return f"flDebug{self.name.replace(' ', '')}Panel" - def nav_title(self): + def nav_title(self) -> str: """Title showing in toolbar""" raise NotImplementedError - def nav_subtitle(self): + def nav_subtitle(self) -> str: """Subtitle showing until title in toolbar""" return "" - def title(self): + def title(self) -> str: """Title showing in panel""" raise NotImplementedError - def url(self): + def url(self) -> str: raise NotImplementedError - def content(self): + def content(self) -> str: raise NotImplementedError # Standard middleware methods - def process_request(self, request): + def process_request(self, request: Request) -> None: pass - def process_view(self, request, view_func, view_kwargs): + def process_view( + self, + request: Request, + view_func: c.Callable[..., t.Any], + view_kwargs: dict[str, t.Any], + ) -> c.Callable[..., t.Any] | None: pass - def process_response(self, request, response): + def process_response(self, request: Request, response: Response) -> None: pass diff --git a/src/flask_debugtoolbar/panels/config_vars.py b/src/flask_debugtoolbar/panels/config_vars.py index c2ab464..7fc6807 100644 --- a/src/flask_debugtoolbar/panels/config_vars.py +++ b/src/flask_debugtoolbar/panels/config_vars.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from flask import current_app from . import DebugPanel @@ -9,16 +11,16 @@ class ConfigVarsDebugPanel(DebugPanel): name = "ConfigVars" has_content = True - def nav_title(self): + def nav_title(self) -> str: return "Config" - def title(self): + def title(self) -> str: return "Config" - def url(self): + def url(self) -> str: return "" - def content(self): + def content(self) -> str: context = self.context.copy() context.update( { diff --git a/src/flask_debugtoolbar/panels/g.py b/src/flask_debugtoolbar/panels/g.py index da7dc14..7387792 100644 --- a/src/flask_debugtoolbar/panels/g.py +++ b/src/flask_debugtoolbar/panels/g.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from flask import g from . import DebugPanel @@ -9,16 +11,16 @@ class GDebugPanel(DebugPanel): name = "g" has_content = True - def nav_title(self): + def nav_title(self) -> str: return "flask.g" - def title(self): + def title(self) -> str: return "flask.g content" - def url(self): + def url(self) -> str: return "" - def content(self): + def content(self) -> str: context = self.context.copy() context.update({"g_content": g.__dict__}) return self.render("panels/g.html", context) diff --git a/src/flask_debugtoolbar/panels/headers.py b/src/flask_debugtoolbar/panels/headers.py index b0e6717..b698d14 100644 --- a/src/flask_debugtoolbar/panels/headers.py +++ b/src/flask_debugtoolbar/panels/headers.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +import typing as t + +from werkzeug import Request + from . import DebugPanel @@ -7,7 +13,7 @@ class HeaderDebugPanel(DebugPanel): name = "Header" has_content = True # List of headers we want to display - header_filter = ( + header_filter: tuple[str, ...] = ( "CONTENT_TYPE", "HTTP_ACCEPT", "HTTP_ACCEPT_CHARSET", @@ -30,25 +36,21 @@ class HeaderDebugPanel(DebugPanel): "SERVER_SOFTWARE", ) - def nav_title(self): + def nav_title(self) -> str: return "HTTP Headers" - def title(self): + def title(self) -> str: return "HTTP Headers" - def url(self): + def url(self) -> str: return "" - def process_request(self, request): - self.headers = dict( - [ - (k, request.environ[k]) - for k in self.header_filter - if k in request.environ - ] - ) + def process_request(self, request: Request) -> None: + self.headers: dict[str, t.Any] = { + k: request.environ[k] for k in self.header_filter if k in request.environ + } - def content(self): + def content(self) -> str: context = self.context.copy() context.update({"headers": self.headers}) return self.render("panels/headers.html", context) diff --git a/src/flask_debugtoolbar/panels/logger.py b/src/flask_debugtoolbar/panels/logger.py index b8528e1..429c90b 100644 --- a/src/flask_debugtoolbar/panels/logger.py +++ b/src/flask_debugtoolbar/panels/logger.py @@ -1,20 +1,27 @@ +from __future__ import annotations + import datetime import logging import threading +from werkzeug import Request + from ..utils import format_fname from . import DebugPanel class ThreadTrackingHandler(logging.Handler): - def __init__(self): + def __init__(self) -> None: super().__init__() - self.records = {} # a dictionary that maps threads to log records + # a dictionary that maps threads to log records + self.records: dict[threading.Thread, list[logging.LogRecord]] = {} - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: self.get_records().append(record) - def get_records(self, thread=None): + def get_records( + self, thread: threading.Thread | None = None + ) -> list[logging.LogRecord]: """ Returns a list of records for the provided thread, of if none is provided, returns a list for the current thread. @@ -27,7 +34,7 @@ def get_records(self, thread=None): return self.records[thread] - def clear_records(self, thread=None): + def clear_records(self, thread: threading.Thread | None = None) -> None: if thread is None: thread = threading.current_thread() @@ -35,11 +42,11 @@ def clear_records(self, thread=None): del self.records[thread] -handler = None +handler: ThreadTrackingHandler = None # type: ignore[assignment] _init_lock = threading.Lock() -def _init_once(): +def _init_once() -> None: global handler if handler is not None: @@ -65,30 +72,30 @@ class LoggingPanel(DebugPanel): name = "Logging" has_content = True - def process_request(self, request): + def process_request(self, request: Request) -> None: _init_once() handler.clear_records() - def get_and_delete(self): + def get_and_delete(self) -> list[logging.LogRecord]: records = handler.get_records() handler.clear_records() return records - def nav_title(self): + def nav_title(self) -> str: return "Logging" - def nav_subtitle(self): + def nav_subtitle(self) -> str: num_records = len(handler.get_records()) plural = "message" if num_records == 1 else "messages" return f"{num_records} {plural}" - def title(self): + def title(self) -> str: return "Log Messages" - def url(self): + def url(self) -> str: return "" - def content(self): + def content(self) -> str: records = [] for record in self.get_and_delete(): diff --git a/src/flask_debugtoolbar/panels/profiler.py b/src/flask_debugtoolbar/panels/profiler.py index 014f0ba..749edce 100644 --- a/src/flask_debugtoolbar/panels/profiler.py +++ b/src/flask_debugtoolbar/panels/profiler.py @@ -1,7 +1,14 @@ +from __future__ import annotations + +import collections.abc as c import functools import pstats +import typing as t from flask import current_app +from jinja2 import Environment +from werkzeug import Request +from werkzeug import Response from ..utils import format_fname from . import DebugPanel @@ -9,7 +16,7 @@ try: import cProfile as profile except ImportError: - import profile + import profile # type: ignore[no-redef] class ProfilerDebugPanel(DebugPanel): @@ -18,7 +25,15 @@ class ProfilerDebugPanel(DebugPanel): name = "Profiler" user_activate = True - def __init__(self, jinja_env, context=None): + is_active: bool = False + dump_filename: str | None = None + profiler: profile.Profile + stats: pstats.Stats | None = None + function_calls: list[dict[str, t.Any]] + + def __init__( + self, jinja_env: Environment, context: dict[str, t.Any] | None = None + ) -> None: super().__init__(jinja_env, context=context) if current_app.config.get("DEBUG_TB_PROFILER_ENABLED"): @@ -27,40 +42,48 @@ def __init__(self, jinja_env, context=None): "DEBUG_TB_PROFILER_DUMP_FILENAME" ) - def has_content(self): + @property + def has_content(self) -> bool: # type: ignore[override] return bool(self.profiler) - def process_request(self, request): + def process_request(self, request: Request) -> None: if not self.is_active: return - self.profiler = profile.Profile() + self.profiler = profile.Profile() # pyright: ignore self.stats = None - def process_view(self, request, view_func, view_kwargs): + def process_view( + self, + request: Request, + view_func: c.Callable[..., t.Any], + view_kwargs: dict[str, t.Any], + ) -> c.Callable[..., t.Any] | None: if self.is_active: func = functools.partial(self.profiler.runcall, view_func) functools.update_wrapper(func, view_func) return func - def process_response(self, request, response): + return None + + def process_response(self, request: Request, response: Response) -> None: if not self.is_active: - return False + return if self.profiler is not None: - self.profiler.disable() + self.profiler.disable() # pyright: ignore try: stats = pstats.Stats(self.profiler) except TypeError: self.is_active = False - return False + return - function_calls = [] + function_calls: list[dict[str, t.Any]] = [] - for func in stats.sort_stats(1).fcn_list: - current = {} - info = stats.stats[func] + for func in stats.sort_stats(1).fcn_list: # type: ignore[attr-defined] + current: dict[str, t.Any] = {} + info = stats.stats[func] # type: ignore[attr-defined] # Number of calls if info[0] != info[1]: @@ -88,7 +111,7 @@ def process_response(self, request, response): current["percall_cum"] = 0 # Filename - filename = pstats.func_std_string(func) + filename = pstats.func_std_string(func) # type: ignore[attr-defined] current["filename_long"] = filename current["filename"] = format_fname(filename) function_calls.append(current) @@ -104,27 +127,25 @@ def process_response(self, request, response): self.profiler.dump_stats(filename) - return response - - def title(self): + def title(self) -> str: if not self.is_active: return "Profiler not active" - return f"View: {float(self.stats.total_tt) * 1000:.2f}ms" + return f"View: {float(self.stats.total_tt) * 1000:.2f}ms" # type: ignore[union-attr] - def nav_title(self): + def nav_title(self) -> str: return "Profiler" - def nav_subtitle(self): + def nav_subtitle(self) -> str: if not self.is_active: return "in-active" - return f"View: {float(self.stats.total_tt) * 1000:.2f}ms" + return f"View: {float(self.stats.total_tt) * 1000:.2f}ms" # type: ignore[union-attr] - def url(self): + def url(self) -> str: return "" - def content(self): + def content(self) -> str: if not self.is_active: return "The profiler is not activated, activate it to use it" diff --git a/src/flask_debugtoolbar/panels/request_vars.py b/src/flask_debugtoolbar/panels/request_vars.py index a886090..113947b 100644 --- a/src/flask_debugtoolbar/panels/request_vars.py +++ b/src/flask_debugtoolbar/panels/request_vars.py @@ -1,4 +1,10 @@ +from __future__ import annotations + +import collections.abc as c +import typing as t + from flask import session +from werkzeug import Request from . import DebugPanel @@ -9,27 +15,31 @@ class RequestVarsDebugPanel(DebugPanel): name = "RequestVars" has_content = True - def nav_title(self): + def nav_title(self) -> str: return "Request Vars" - def title(self): + def title(self) -> str: return "Request Vars" - def url(self): + def url(self) -> str: return "" - def process_request(self, request): + def process_request(self, request: Request) -> None: self.request = request self.session = session - self.view_func = None - self.view_args = [] - self.view_kwargs = {} - - def process_view(self, request, view_func, view_kwargs): + self.view_func: c.Callable[..., t.Any] | None = None + self.view_kwargs: dict[str, t.Any] = {} + + def process_view( + self, + request: Request, + view_func: c.Callable[..., t.Any], + view_kwargs: dict[str, t.Any], + ) -> None: self.view_func = view_func self.view_kwargs = view_kwargs - def content(self): + def content(self) -> str: context = self.context.copy() context.update( { @@ -41,7 +51,6 @@ def content(self): if self.view_func else "[unknown]" ), - "view_args": self.view_args, "view_kwargs": self.view_kwargs or {}, "session": self.session.items(), } diff --git a/src/flask_debugtoolbar/panels/route_list.py b/src/flask_debugtoolbar/panels/route_list.py index 461815c..d24782b 100644 --- a/src/flask_debugtoolbar/panels/route_list.py +++ b/src/flask_debugtoolbar/panels/route_list.py @@ -1,4 +1,8 @@ +from __future__ import annotations + from flask import current_app +from werkzeug import Request +from werkzeug.routing import Rule from . import DebugPanel @@ -8,30 +12,30 @@ class RouteListDebugPanel(DebugPanel): name = "RouteList" has_content = True - routes = [] + routes: list[Rule] = [] - def nav_title(self): + def nav_title(self) -> str: return "Route List" - def title(self): + def title(self) -> str: return "Route List" - def url(self): + def url(self) -> str: return "" - def nav_subtitle(self): + def nav_subtitle(self) -> str: count = len(self.routes) plural = "route" if count == 1 else "routes" return f"{count} {plural}" - def process_request(self, request): + def process_request(self, request: Request) -> None: self.routes = [ rule for rule in current_app.url_map.iter_rules() if not rule.rule.startswith("/_debug_toolbar") ] - def content(self): + def content(self) -> str: return self.render( "panels/route_list.html", { diff --git a/src/flask_debugtoolbar/panels/sqlalchemy.py b/src/flask_debugtoolbar/panels/sqlalchemy.py index 49e6db4..3cd4c6e 100644 --- a/src/flask_debugtoolbar/panels/sqlalchemy.py +++ b/src/flask_debugtoolbar/panels/sqlalchemy.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import typing as t + import itsdangerous from flask import abort from flask import current_app @@ -12,40 +16,48 @@ try: from flask_sqlalchemy import SQLAlchemy except ImportError: - sqlalchemy_available = False - get_recorded_queries = SQLAlchemy = None - debug_enables_record_queries = False + sqlalchemy_available: bool = False + get_recorded_queries = SQLAlchemy = None # type: ignore[misc, assignment] + debug_enables_record_queries: bool = False else: try: - from flask_sqlalchemy.record_queries import get_recorded_queries + from flask_sqlalchemy.record_queries import ( # type: ignore[assignment] + get_recorded_queries, + ) debug_enables_record_queries = False except ImportError: # For flask_sqlalchemy < 3.0.0 - from flask_sqlalchemy import get_debug_queries as get_recorded_queries + from flask_sqlalchemy import ( # type: ignore[no-redef] + get_debug_queries as get_recorded_queries, + ) # flask_sqlalchemy < 3.0.0 automatically enabled # SQLALCHEMY_RECORD_QUERIES in debug or test mode debug_enables_record_queries = True - location_property = "context" + location_property: str = "context" else: location_property = "location" sqlalchemy_available = True -def query_signer(): +def query_signer() -> itsdangerous.URLSafeSerializer: return itsdangerous.URLSafeSerializer( current_app.config["SECRET_KEY"], salt="fdt-sql-query" ) -def is_select(statement): - prefix = b"select" if isinstance(statement, bytes) else "select" - return statement.lower().strip().startswith(prefix) +def is_select(statement: str | bytes) -> bool: + statement = statement.lower().strip() + + if isinstance(statement, bytes): + return statement.startswith(b"select") + return statement.startswith("select") # pyright: ignore -def dump_query(statement, params): + +def dump_query(statement: str, params: t.Any) -> str | None: if not params or not is_select(statement): return None @@ -55,9 +67,9 @@ def dump_query(statement, params): return None -def load_query(data): +def load_query(data: str) -> tuple[str, t.Any]: try: - statement, params = query_signer().loads(request.args["query"]) + statement, params = query_signer().loads(data) except (itsdangerous.BadSignature, TypeError): abort(406) @@ -68,21 +80,21 @@ def load_query(data): return statement, params -def extension_used(): +def extension_used() -> bool: return "sqlalchemy" in current_app.extensions -def recording_enabled(): +def recording_enabled() -> bool: return ( debug_enables_record_queries and current_app.debug - ) or current_app.config.get("SQLALCHEMY_RECORD_QUERIES") + ) or current_app.config.get("SQLALCHEMY_RECORD_QUERIES", False) -def is_available(): +def is_available() -> bool: return sqlalchemy_available and extension_used() and recording_enabled() -def get_queries(): +def get_queries() -> list[t.Any]: if get_recorded_queries: return get_recorded_queries() else: @@ -95,19 +107,13 @@ class SQLAlchemyDebugPanel(DebugPanel): name = "SQLAlchemy" @property - def has_content(self): + def has_content(self) -> bool: # type: ignore[override] return bool(get_queries()) or not is_available() - def process_request(self, request): - pass - - def process_response(self, request, response): - pass - - def nav_title(self): + def nav_title(self) -> str: return "SQLAlchemy" - def nav_subtitle(self): + def nav_subtitle(self) -> str: count = len(get_queries()) if not count and not is_available(): @@ -116,13 +122,13 @@ def nav_subtitle(self): plural = "query" if count == 1 else "queries" return f"{count} {plural}" - def title(self): + def title(self) -> str: return "SQLAlchemy queries" - def url(self): + def url(self) -> str: return "" - def content(self): + def content(self) -> str: queries = get_queries() if not queries and not is_available(): @@ -158,9 +164,9 @@ def content(self): @module.route( "/sqlalchemy/sql_explain", methods=["GET", "POST"], defaults=dict(explain=True) ) -def sql_select(explain=False): +def sql_select(explain: bool = False) -> str: statement, params = load_query(request.args["query"]) - engine = SQLAlchemy().get_engine(current_app) + engine = current_app.extensions["sqlalchemy"].engine if explain: if engine.driver == "pysqlite": @@ -169,7 +175,7 @@ def sql_select(explain=False): statement = f"EXPLAIN\n{statement}" result = engine.execute(statement, params) - return g.debug_toolbar.render( + return g.debug_toolbar.render( # type: ignore[no-any-return] "panels/sqlalchemy_select.html", { "result": result.fetchall(), diff --git a/src/flask_debugtoolbar/panels/template.py b/src/flask_debugtoolbar/panels/template.py index 6a6250c..71e55cb 100644 --- a/src/flask_debugtoolbar/panels/template.py +++ b/src/flask_debugtoolbar/panels/template.py @@ -1,7 +1,10 @@ -import collections +from __future__ import annotations + import json import sys +import typing as t import uuid +from collections import deque from flask import abort from flask import current_app @@ -10,6 +13,7 @@ from flask import Response from flask import template_rendered from flask import url_for +from jinja2 import Template from .. import module from . import DebugPanel @@ -22,23 +26,23 @@ class TemplateDebugPanel(DebugPanel): has_content = True # save the context for the 5 most recent requests - template_cache = collections.deque(maxlen=5) + template_cache: deque[tuple[str, list[dict[str, t.Any]]]] = deque(maxlen=5) @classmethod - def get_cache_for_key(self, key): - for cache_key, value in self.template_cache: + def get_cache_for_key(cls, key: str) -> list[dict[str, t.Any]]: + for cache_key, value in cls.template_cache: if key == cache_key: return value raise KeyError(key) - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) - self.key = str(uuid.uuid4()) - self.templates = [] + self.key: str = str(uuid.uuid4()) + self.templates: list[dict[str, t.Any]] = [] template_rendered.connect(self._store_template_info) - def _store_template_info(self, sender, **kwargs): + def _store_template_info(self, sender: t.Any, **kwargs: t.Any) -> None: # only record in the cache if the editor is enabled and there is # actually a template for this request if not self.templates and is_editor_enabled(): @@ -46,25 +50,19 @@ def _store_template_info(self, sender, **kwargs): self.templates.append(kwargs) - def process_request(self, request): - pass - - def process_response(self, request, response): - pass - - def nav_title(self): + def nav_title(self) -> str: return "Templates" - def nav_subtitle(self): + def nav_subtitle(self) -> str: return f"{len(self.templates)} rendered" - def title(self): + def title(self) -> str: return "Templates" - def url(self): + def url(self) -> str: return "" - def content(self): + def content(self) -> str: return self.render( "panels/template.html", { @@ -75,33 +73,36 @@ def content(self): ) -def is_editor_enabled(): - return current_app.config.get("DEBUG_TB_TEMPLATE_EDITOR_ENABLED") +def is_editor_enabled() -> bool: + return current_app.config.get("DEBUG_TB_TEMPLATE_EDITOR_ENABLED", False) # type: ignore -def require_enabled(): +def require_enabled() -> None: if not is_editor_enabled(): abort(403) -def _get_source(template): +def _get_source(template: Template) -> str: + if template.filename is None: + return "" + with open(template.filename, "rb") as fp: source = fp.read() return source.decode(_template_encoding()) -def _template_encoding(): +def _template_encoding() -> str: return getattr(current_app.jinja_loader, "encoding", "utf-8") @module.route("/template/") -def template_editor(key): +def template_editor(key: str) -> str: require_enabled() # TODO set up special loader that caches templates it loads # and can override template contents templates = [t["template"] for t in TemplateDebugPanel.get_cache_for_key(key)] - return g.debug_toolbar.render( + return g.debug_toolbar.render( # type: ignore[no-any-return] "panels/template_editor.html", { "static_path": url_for("_debug_toolbar.static", filename=""), @@ -114,7 +115,7 @@ def template_editor(key): @module.route("/template//save", methods=["POST"]) -def save_template(key): +def save_template(key: str) -> str: require_enabled() template = TemplateDebugPanel.get_cache_for_key(key)[0]["template"] content = request.form["content"].encode(_template_encoding()) @@ -126,7 +127,7 @@ def save_template(key): @module.route("/template/", methods=["POST"]) -def template_preview(key): +def template_preview(key: str) -> str | Response: require_enabled() context = TemplateDebugPanel.get_cache_for_key(key)[0]["context"] content = request.form["content"] @@ -139,10 +140,10 @@ def template_preview(key): tb = sys.exc_info()[2] try: - while tb.tb_next: - tb = tb.tb_next + while tb.tb_next: # type: ignore[union-attr] + tb = tb.tb_next # type: ignore[union-attr] - msg = {"lineno": tb.tb_lineno, "error": str(e)} + msg = {"lineno": tb.tb_lineno, "error": str(e)} # type: ignore[union-attr] return Response(json.dumps(msg), status=400, mimetype="application/json") finally: del tb diff --git a/src/flask_debugtoolbar/panels/timer.py b/src/flask_debugtoolbar/panels/timer.py index 2478ae0..3fb0f3c 100644 --- a/src/flask_debugtoolbar/panels/timer.py +++ b/src/flask_debugtoolbar/panels/timer.py @@ -1,5 +1,10 @@ +from __future__ import annotations + import time +from werkzeug import Request +from werkzeug import Response + from . import DebugPanel try: @@ -16,22 +21,22 @@ class TimerDebugPanel(DebugPanel): name = "Timer" has_content = HAVE_RESOURCE - def process_request(self, request): + def process_request(self, request: Request) -> None: self._start_time = time.time() if HAVE_RESOURCE: self._start_rusage = resource.getrusage(resource.RUSAGE_SELF) - def process_response(self, request, response): - self.total_time = (time.time() - self._start_time) * 1000 + def process_response(self, request: Request, response: Response) -> None: + self.total_time: float = (time.time() - self._start_time) * 1000 if HAVE_RESOURCE: self._end_rusage = resource.getrusage(resource.RUSAGE_SELF) - def nav_title(self): + def nav_title(self) -> str: return "Time" - def nav_subtitle(self): + def nav_subtitle(self) -> str: if not HAVE_RESOURCE: return f"TOTAL: {self.total_time:0.2f}ms" @@ -39,16 +44,16 @@ def nav_subtitle(self): stime = self._end_rusage.ru_stime - self._start_rusage.ru_stime return f"CPU: {(utime + stime) * 1000.0:0.2f}ms ({self.total_time:0.2f}ms)" - def title(self): + def title(self) -> str: return "Resource Usage" - def url(self): + def url(self) -> str: return "" - def _elapsed_ru(self, name): - return getattr(self._end_rusage, name) - getattr(self._start_rusage, name) + def _elapsed_ru(self, name: str) -> float: + return getattr(self._end_rusage, name) - getattr(self._start_rusage, name) # type: ignore[no-any-return] - def content(self): + def content(self) -> str: utime = 1000 * self._elapsed_ru("ru_utime") stime = 1000 * self._elapsed_ru("ru_stime") vcsw = self._elapsed_ru("ru_nvcsw") diff --git a/src/flask_debugtoolbar/panels/versions.py b/src/flask_debugtoolbar/panels/versions.py index 2fd07a7..6b70cb0 100644 --- a/src/flask_debugtoolbar/panels/versions.py +++ b/src/flask_debugtoolbar/panels/versions.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import importlib.metadata import os from sysconfig import get_path from . import DebugPanel -flask_version = importlib.metadata.version("flask") +flask_version: str = importlib.metadata.version("flask") class VersionDebugPanel(DebugPanel): @@ -13,19 +15,19 @@ class VersionDebugPanel(DebugPanel): name = "Version" has_content = True - def nav_title(self): + def nav_title(self) -> str: return "Versions" - def nav_subtitle(self): + def nav_subtitle(self) -> str: return f"Flask {flask_version}" - def url(self): + def url(self) -> str: return "" - def title(self): + def title(self) -> str: return "Versions" - def content(self): + def content(self) -> str: packages_metadata = [p.metadata for p in importlib.metadata.distributions()] packages = sorted(packages_metadata, key=lambda p: p["Name"].lower()) return self.render( diff --git a/src/flask_debugtoolbar/py.typed b/src/flask_debugtoolbar/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/flask_debugtoolbar/templates/panels/request_vars.html b/src/flask_debugtoolbar/templates/panels/request_vars.html index a018224..aa871d9 100644 --- a/src/flask_debugtoolbar/templates/panels/request_vars.html +++ b/src/flask_debugtoolbar/templates/panels/request_vars.html @@ -4,14 +4,12 @@

View information

View Function - args kwargs {{ view_func }} - {{ view_args|default("None") }} {% if view_kwargs.items() %} {% for k, v in view_kwargs.items() %} diff --git a/src/flask_debugtoolbar/toolbar.py b/src/flask_debugtoolbar/toolbar.py index 51e26d9..97c4e2c 100644 --- a/src/flask_debugtoolbar/toolbar.py +++ b/src/flask_debugtoolbar/toolbar.py @@ -1,28 +1,35 @@ +from __future__ import annotations + +import collections.abc as c +import typing as t from urllib.parse import unquote from flask import current_app +from flask import Flask from flask import url_for +from jinja2 import Environment +from werkzeug import Request from werkzeug.utils import import_string +from .panels import DebugPanel + class DebugToolbar: - _cached_panel_classes = {} + _cached_panel_classes: t.ClassVar[dict[str, type[DebugPanel] | None]] = {} - def __init__(self, request, jinja_env): + def __init__(self, request: Request, jinja_env: Environment) -> None: self.jinja_env = jinja_env self.request = request - self.panels = [] - - self.template_context = { + self.panels: list[DebugPanel] = [] + self.template_context: dict[str, t.Any] = { "static_path": url_for("_debug_toolbar.static", filename="") } - self.create_panels() - def create_panels(self): + def create_panels(self) -> None: """Populate debug panels""" - activated = self.request.cookies.get("fldt_active", "") - activated = unquote(activated).split(";") + activated_str = self.request.cookies.get("fldt_active", "") + activated = unquote(activated_str).split(";") for panel_class in self._iter_panels(current_app): panel_instance = panel_class( @@ -34,21 +41,20 @@ def create_panels(self): self.panels.append(panel_instance) - def render_toolbar(self): + def render_toolbar(self) -> str: context = self.template_context.copy() context.update({"panels": self.panels}) - template = self.jinja_env.get_template("base.html") return template.render(**context) @classmethod - def load_panels(cls, app): + def load_panels(cls, app: Flask) -> None: for panel_class in cls._iter_panels(app): # Call `.init_app()` on panels panel_class.init_app(app) @classmethod - def _iter_panels(cls, app): + def _iter_panels(cls, app: Flask) -> c.Iterator[type[DebugPanel]]: for panel_path in app.config["DEBUG_TB_PANELS"]: panel_class = cls._import_panel(app, panel_path) @@ -56,7 +62,7 @@ def _iter_panels(cls, app): yield panel_class @classmethod - def _import_panel(cls, app, path): + def _import_panel(cls, app: Flask, path: str) -> type[DebugPanel] | None: cache = cls._cached_panel_classes try: @@ -65,7 +71,7 @@ def _import_panel(cls, app, path): pass try: - panel_class = import_string(path) + panel_class: type[DebugPanel] | None = import_string(path) except ImportError as e: app.logger.warning("Disabled %s due to ImportError: %s", path, e) panel_class = None diff --git a/src/flask_debugtoolbar/utils.py b/src/flask_debugtoolbar/utils.py index 50144a3..d16c030 100644 --- a/src/flask_debugtoolbar/utils.py +++ b/src/flask_debugtoolbar/utils.py @@ -1,8 +1,12 @@ +from __future__ import annotations + +import collections.abc as c import gzip import io import itertools import os.path import sys +from types import ModuleType from flask import current_app from markupsafe import Markup @@ -19,14 +23,14 @@ HAVE_PYGMENTS = False try: - import sqlparse + import sqlparse # pyright: ignore HAVE_SQLPARSE = True except ImportError: HAVE_SQLPARSE = False -def format_fname(value): +def format_fname(value: str) -> str: # If the value has a builtin prefix, return it unchanged if value.startswith(("{", "<")): return value @@ -46,12 +50,16 @@ def format_fname(value): return f"<{_shortest_relative_path(value, sys.path, os.path)}>" -def _shortest_relative_path(value, paths, path_module): +def _shortest_relative_path( + value: str, paths: list[str], path_module: ModuleType +) -> str: relpaths = _relative_paths(value, paths, path_module) return min(itertools.chain(relpaths, [value]), key=len) -def _relative_paths(value, paths, path_module): +def _relative_paths( + value: str, paths: list[str], path_module: ModuleType +) -> c.Iterator[str]: for path in paths: try: relval = path_module.relpath(value, path) @@ -64,7 +72,7 @@ def _relative_paths(value, paths, path_module): yield relval -def decode_text(value): +def decode_text(value: str | bytes) -> str: """ Decode a text-like value for display. @@ -73,11 +81,11 @@ def decode_text(value): """ if isinstance(value, bytes): return value.decode("ascii", "replace") - else: - return value + + return value # pyright: ignore -def format_sql(query, args): +def format_sql(query: str | bytes, args: object) -> str: if HAVE_SQLPARSE: query = sqlparse.format(query, reindent=True, keyword_case="upper") @@ -89,7 +97,7 @@ def format_sql(query, args): ) -def gzip_compress(data, compresslevel=6): +def gzip_compress(data: bytes, compresslevel: int = 6) -> bytes: buff = io.BytesIO() with gzip.GzipFile(fileobj=buff, mode="wb", compresslevel=compresslevel) as f: @@ -98,6 +106,6 @@ def gzip_compress(data, compresslevel=6): return buff.getvalue() -def gzip_decompress(data): +def gzip_decompress(data: bytes) -> bytes: with gzip.GzipFile(fileobj=io.BytesIO(data), mode="rb") as f: return f.read() diff --git a/tests/basic_app.py b/tests/basic_app.py index ed43534..5c76467 100644 --- a/tests/basic_app.py +++ b/tests/basic_app.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from flask import Flask from flask import render_template from flask_sqlalchemy import SQLAlchemy @@ -21,13 +23,13 @@ db = SQLAlchemy(app) -class Foo(db.Model): +class Foo(db.Model): # type: ignore[name-defined, misc] __tablename__ = "foo" id = db.Column(db.Integer, primary_key=True) @app.route("/") -def index(): +def index() -> str: Foo.query.filter_by(id=1).all() return render_template("basic_app.html") diff --git a/tests/conftest.py b/tests/conftest.py index efc3b7d..d0a2a9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import pytest @pytest.fixture(autouse=True) -def mock_env_development(monkeypatch): +def mock_env_development(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("FLASK_ENV", "development") diff --git a/tests/test_toolbar.py b/tests/test_toolbar.py index f1d6553..f392def 100644 --- a/tests/test_toolbar.py +++ b/tests/test_toolbar.py @@ -1,10 +1,16 @@ -def load_app(name): - app = __import__(name).app +from __future__ import annotations + +from flask import Flask +from flask.testing import FlaskClient + + +def load_app(name: str) -> FlaskClient: + app: Flask = __import__(name).app app.config["TESTING"] = True return app.test_client() -def test_basic_app(): +def test_basic_app() -> None: app = load_app("basic_app") index = app.get("/") assert index.status_code == 200 diff --git a/tests/test_utils.py b/tests/test_utils.py index 81ad16b..eab4114 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import ntpath import posixpath +from types import ModuleType import pytest from markupsafe import escape @@ -34,7 +37,9 @@ ("c:\\Foo\\Bar", ["c:\\foo"], ["Bar"], ntpath), ], ) -def test_relative_paths(value, paths, expected, path_module): +def test_relative_paths( + value: str, paths: list[str], expected: list[str], path_module: ModuleType +) -> None: assert list(_relative_paths(value, paths, path_module)) == expected @@ -52,22 +57,24 @@ def test_relative_paths(value, paths, expected, path_module): ("c:\\foo\\bar\\baz", ["c:\\foo", "c:\\foo\\bar"], "baz", ntpath), ], ) -def test_shortest_relative_path(value, paths, expected, path_module): +def test_shortest_relative_path( + value: str, paths: list[str], expected: str, path_module: ModuleType +) -> None: assert _shortest_relative_path(value, paths, path_module) == expected -def test_decode_text_unicode(): +def test_decode_text_unicode() -> None: value = "\uffff" decoded = decode_text(value) assert decoded == value -def test_decode_text_ascii(): +def test_decode_text_ascii() -> None: value = "abc" assert decode_text(value.encode("ascii")) == value -def test_decode_text_non_ascii(): +def test_decode_text_non_ascii() -> None: value = b"abc \xff xyz" assert isinstance(value, bytes) @@ -79,22 +86,25 @@ def test_decode_text_non_ascii(): @pytest.fixture() -def no_pygments(monkeypatch): +def no_pygments(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("flask_debugtoolbar.utils.HAVE_PYGMENTS", False) -def test_format_sql_no_pygments(no_pygments): +@pytest.mark.usefixtures("no_pygments") +def test_format_sql_no_pygments() -> None: sql = "select 1" assert format_sql(sql, {}) == sql -def test_format_sql_no_pygments_non_ascii(no_pygments): +@pytest.mark.usefixtures("no_pygments") +def test_format_sql_no_pygments_non_ascii() -> None: sql = b"select '\xff'" formatted = format_sql(sql, {}) assert formatted.startswith("select '") -def test_format_sql_no_pygments_escape_html(no_pygments): +@pytest.mark.usefixtures("no_pygments") +def test_format_sql_no_pygments_escape_html() -> None: sql = "select x < 1" formatted = format_sql(sql, {}) assert not isinstance(formatted, Markup) @@ -102,7 +112,7 @@ def test_format_sql_no_pygments_escape_html(no_pygments): @pytest.mark.skipif(not HAVE_PYGMENTS, reason='test requires the "Pygments" library') -def test_format_sql_pygments(): +def test_format_sql_pygments() -> None: sql = "select 1" html = format_sql(sql, {}) assert isinstance(html, Markup) @@ -112,7 +122,7 @@ def test_format_sql_pygments(): @pytest.mark.skipif(not HAVE_PYGMENTS, reason='test requires the "Pygments" library') -def test_format_sql_pygments_non_ascii(): +def test_format_sql_pygments_non_ascii() -> None: sql = b"select 'abc \xff xyz'" html = format_sql(sql, {}) assert isinstance(html, Markup) diff --git a/tox.ini b/tox.ini index cb9cacb..8f8342e 100644 --- a/tox.ini +++ b/tox.ini @@ -3,6 +3,7 @@ envlist = py3{12,11,10,9,8} minimal style + typing docs skip_missing_interpreters = true @@ -41,6 +42,7 @@ skip_install = true commands = pre-commit autoupdate -j4 [testenv:update-requirements] +base_python = 3.8 labels = update deps = pip-tools skip_install = true