refactor: create an abstraction for each instrument
art049 committed Sep 10, 2024
1 parent 88ca4f6 commit 048a190
Showing 6 changed files with 181 additions and 91 deletions.
44 changes: 44 additions & 0 deletions src/pytest_codspeed/instruments/__init__.py
@@ -0,0 +1,44 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Callable, ParamSpec, TypeVar

    import pytest

    T = TypeVar("T")
    P = ParamSpec("P")


class Instrument(metaclass=ABCMeta):
    @abstractmethod
    def __init__(self): ...

    @abstractmethod
    def get_instrument_config_str_and_warns(self) -> tuple[str, list[str]]: ...

    @abstractmethod
    def measure(
        self,
        name: str,
        uri: str,
        fn: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T: ...

    @abstractmethod
    def report(self, session: pytest.Session) -> None: ...


class CodSpeedMeasurementMode(str, Enum):
    Instrumentation = "instrumentation"


def get_instrument_from_mode(mode: CodSpeedMeasurementMode) -> type[Instrument]:
    from pytest_codspeed.instruments.instrumentation import InstrumentationInstrument

    return InstrumentationInstrument
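
To illustrate what the new abstraction buys, here is a minimal sketch of a hypothetical second instrument implementing the same interface. It is not part of this commit; the WallTimeInstrument name and its timing logic are illustrative only.

# Hypothetical sketch (not in this commit): a wall-time instrument
# implementing the Instrument ABC introduced above.
from __future__ import annotations

import time
from typing import TYPE_CHECKING

import pytest

from pytest_codspeed.instruments import Instrument

if TYPE_CHECKING:
    from typing import Callable

    from pytest_codspeed.instruments import P, T


class WallTimeInstrument(Instrument):
    def __init__(self) -> None:
        self.results: list[tuple[str, float]] = []

    def get_instrument_config_str_and_warns(self) -> tuple[str, list[str]]:
        return "mode: walltime", []

    def measure(
        self,
        name: str,
        uri: str,
        fn: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T:
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            # Record elapsed wall time per benchmark uri
            self.results.append((uri, time.perf_counter() - start))

    def report(self, session: pytest.Session) -> None:
        reporter = session.config.pluginmanager.get_plugin("terminalreporter")
        reporter.write_sep("=", f"{len(self.results)} benchmarks timed")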
84 changes: 84 additions & 0 deletions src/pytest_codspeed/instruments/instrumentation.py
@@ -0,0 +1,84 @@
from __future__ import annotations

import os
import sys
from typing import TYPE_CHECKING

from pytest_codspeed import __version__
from pytest_codspeed._wrapper import get_lib
from pytest_codspeed.instruments import Instrument

if TYPE_CHECKING:
    from typing import Callable

    from pytest import Session

    from pytest_codspeed.instruments import P, T

SUPPORTS_PERF_TRAMPOLINE = sys.version_info >= (3, 12)


class InstrumentationInstrument(Instrument):
    def __init__(self):
        self.benchmark_count = 0
        self.should_measure = os.environ.get("CODSPEED_ENV") is not None
        if self.should_measure:
            self.lib = get_lib()
            self.lib.dump_stats_at(
                f"Metadata: pytest-codspeed {__version__}".encode("ascii")
            )
            if SUPPORTS_PERF_TRAMPOLINE:
                sys.activate_stack_trampoline("perf")
        else:
            self.lib = None

    def get_instrument_config_str_and_warns(self) -> tuple[str, list[str]]:
        config = (
            f"mode: instrumentation, "
            f"callgraph: {'enabled' if SUPPORTS_PERF_TRAMPOLINE else 'not supported'}"
        )
        warnings = []
        if not self.should_measure:
            warnings.append(
                "\033[1m"
                "NOTICE: codspeed is enabled, but no performance measurement"
                " will be made since it's running in an unknown environment."
                "\033[0m"
            )
        return config, warnings

    def measure(
        self,
        name: str,
        uri: str,
        fn: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T:
        self.benchmark_count += 1
        if self.lib is None:  # Thus should_measure is False
            return fn(*args, **kwargs)

        def __codspeed_root_frame__() -> T:
            return fn(*args, **kwargs)

        if SUPPORTS_PERF_TRAMPOLINE:
            # Warmup CPython performance map cache
            __codspeed_root_frame__()

        self.lib.zero_stats()
        self.lib.start_instrumentation()
        try:
            return __codspeed_root_frame__()
        finally:
            # Ensure instrumentation is stopped even if the test failed
            self.lib.stop_instrumentation()
            self.lib.dump_stats_at(uri.encode("ascii"))

    def report(self, session: Session) -> None:
        reporter = session.config.pluginmanager.get_plugin("terminalreporter")
        count_suffix = "benchmarked" if self.should_measure else "benchmark tested"
        reporter.write_sep(
            "=",
            f"{self.benchmark_count} {count_suffix}",
        )
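
Note the fallback path in measure(): when CODSPEED_ENV is unset, lib stays None and the benchmark is simply counted and called. A quick illustration of that behavior outside a CodSpeed run, assuming a normal pytest-codspeed install (the name and uri strings are made up):

# Illustration only: outside a CodSpeed environment (CODSPEED_ENV unset),
# measure() just counts the benchmark and calls the function directly.
from pytest_codspeed.instruments.instrumentation import InstrumentationInstrument

instrument = InstrumentationInstrument()  # should_measure is False here

result = instrument.measure(
    "test_sum",                      # benchmark name (hypothetical)
    "tests/test_math.py::test_sum",  # benchmark uri (hypothetical)
    sum,
    range(10),
)
assert result == 45
assert instrument.benchmark_count == 1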
113 changes: 34 additions & 79 deletions src/pytest_codspeed/plugin.py
@@ -4,28 +4,29 @@
 import gc
 import os
 import pkgutil
-import sys
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import pytest
 from _pytest.fixtures import FixtureManager
 
-from pytest_codspeed.utils import get_git_relative_uri
+from pytest_codspeed.instruments import (
+    CodSpeedMeasurementMode,
+    get_instrument_from_mode,
+)
+from pytest_codspeed.utils import get_git_relative_uri_and_name
 
 from . import __version__
-from ._wrapper import get_lib
 
 if TYPE_CHECKING:
     from typing import Any, Callable, ParamSpec, TypeVar
 
-    from ._wrapper import LibType
+    from pytest_codspeed.instruments import Instrument
 
     T = TypeVar("T")
     P = ParamSpec("P")
 
 IS_PYTEST_BENCHMARK_INSTALLED = pkgutil.find_loader("pytest_benchmark") is not None
-SUPPORTS_PERF_TRAMPOLINE = sys.version_info >= (3, 12)
 BEFORE_PYTEST_8_1_1 = pytest.version_tuple < (8, 1, 1)
@@ -43,8 +44,8 @@ def pytest_addoption(parser: pytest.Parser):
 @dataclass(unsafe_hash=True)
 class CodSpeedPlugin:
     is_codspeed_enabled: bool
-    should_measure: bool
-    lib: LibType | None
+    mode: CodSpeedMeasurementMode
+    instrument: Instrument
     disabled_plugins: tuple[str, ...]
     benchmark_count: int = field(default=0, hash=False, compare=False)
 
@@ -67,10 +68,8 @@ def pytest_configure(config: pytest.Config):
     is_codspeed_enabled = (
         config.getoption("--codspeed") or os.environ.get("CODSPEED_ENV") is not None
     )
-    should_measure = os.environ.get("CODSPEED_ENV") is not None
-    lib = get_lib() if should_measure else None
-    if lib is not None:
-        lib.dump_stats_at(f"Metadata: pytest-codspeed {__version__}".encode("ascii"))
+    mode = CodSpeedMeasurementMode.Instrumentation
+    instrument = get_instrument_from_mode(mode)
     disabled_plugins: list[str] = []
     # Disable pytest-benchmark if codspeed is enabled
     if is_codspeed_enabled and IS_PYTEST_BENCHMARK_INSTALLED:
@@ -79,14 +78,15 @@
         disabled_plugins.append("pytest-benchmark")
 
     plugin = CodSpeedPlugin(
-        is_codspeed_enabled=is_codspeed_enabled,
-        should_measure=should_measure,
-        lib=lib,
         disabled_plugins=tuple(disabled_plugins),
+        is_codspeed_enabled=is_codspeed_enabled,
+        mode=mode,
+        instrument=instrument(),
     )
     config.pluginmanager.register(plugin, PLUGIN_NAME)
 
 
+@pytest.hookimpl()
 def pytest_plugin_registered(plugin, manager: pytest.PytestPluginManager):
     """Patch the benchmark fixture to use the codspeed one if codspeed is enabled"""
     if IS_PYTEST_BENCHMARK_INSTALLED and isinstance(plugin, FixtureManager):
@@ -110,18 +110,9 @@ def pytest_plugin_registered(plugin, manager: pytest.PytestPluginManager):
 
 @pytest.hookimpl(trylast=True)
 def pytest_report_header(config: pytest.Config):
-    out = [
-        f"codspeed: {__version__} "
-        f"(callgraph: {'enabled' if SUPPORTS_PERF_TRAMPOLINE else 'not supported'})"
-    ]
     plugin = get_plugin(config)
-    if plugin.is_codspeed_enabled and not plugin.should_measure:
-        out.append(
-            "\033[1m"
-            "NOTICE: codspeed is enabled, but no performance measurement"
-            " will be made since it's running in an unknown environment."
-            "\033[0m"
-        )
+    config_str, warns = plugin.instrument.get_instrument_config_str_and_warns()
+    out = [f"codspeed: {__version__} ({config_str})", *warns]
     if len(plugin.disabled_plugins) > 0:
         out.append(
             "\033[93mCodSpeed had to disable the following plugins: "
@@ -146,19 +137,11 @@ def should_benchmark_item(item: pytest.Item) -> bool:
     return has_benchmark_fixture(item) or has_benchmark_marker(item)
 
 
-@pytest.hookimpl()
-def pytest_sessionstart(session: pytest.Session):
-    plugin = get_plugin(session.config)
-    if plugin.is_codspeed_enabled:
-        plugin.benchmark_count = 0
-        if plugin.should_measure and SUPPORTS_PERF_TRAMPOLINE:
-            sys.activate_stack_trampoline("perf")  # type: ignore
-
-
 @pytest.hookimpl(trylast=True)
 def pytest_collection_modifyitems(
     session: pytest.Session, config: pytest.Config, items: list[pytest.Item]
 ):
+    """Filter out items that should not be benchmarked when codspeed is enabled"""
     plugin = get_plugin(config)
     if plugin.is_codspeed_enabled:
         deselected = []
@@ -172,8 +155,8 @@ def pytest_collection_modifyitems(
         items[:] = selected
 
 
-def _run_with_instrumentation(
-    lib: LibType,
+def _measure(
+    plugin: CodSpeedPlugin,
     nodeid: str,
     config: pytest.Config,
     fn: Callable[P, T],
@@ -184,39 +167,24 @@
     if is_gc_enabled:
         gc.collect()
         gc.disable()
-
-    def __codspeed_root_frame__() -> T:
-        return fn(*args, **kwargs)
-
     try:
-        if SUPPORTS_PERF_TRAMPOLINE:
-            # Warmup CPython performance map cache
-            __codspeed_root_frame__()
-
-        lib.zero_stats()
-        lib.start_instrumentation()
-        try:
-            return __codspeed_root_frame__()
-        finally:
-            # Ensure instrumentation is stopped even if the test failed
-            lib.stop_instrumentation()
-            uri = get_git_relative_uri(nodeid, config.rootpath)
-            lib.dump_stats_at(uri.encode("ascii"))
+        uri, name = get_git_relative_uri_and_name(nodeid, config.rootpath)
+        return plugin.instrument.measure(name, uri, fn, *args, **kwargs)
     finally:
         # Ensure GC is re-enabled even if the test failed
         if is_gc_enabled:
            gc.enable()
 
 
 def wrap_runtest(
-    lib: LibType,
+    plugin: CodSpeedPlugin,
     nodeid: str,
     config: pytest.Config,
     fn: Callable[P, T],
 ) -> Callable[P, T]:
     @functools.wraps(fn)
     def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
-        return _run_with_instrumentation(lib, nodeid, config, fn, *args, **kwargs)
+        return _measure(plugin, nodeid, config, fn, *args, **kwargs)
 
     return wrapped
 
@@ -232,17 +200,18 @@ def pytest_runtest_protocol(item: pytest.Item, nextitem: pytest.Item | None):
         # Instrumentation is handled by the fixture
         return None
 
-    plugin.benchmark_count += 1
-    if not plugin.should_measure:
-        # Benchmark counted but will be run in the default protocol
-        return None
-
     # Wrap runtest and defer to default protocol
-    assert plugin.lib is not None
-    item.runtest = wrap_runtest(plugin.lib, item.nodeid, item.config, item.runtest)
+    item.runtest = wrap_runtest(plugin, item.nodeid, item.config, item.runtest)
     return None
 
 
+@pytest.hookimpl()
+def pytest_sessionfinish(session: pytest.Session, exitstatus):
+    plugin = get_plugin(session.config)
+    if plugin.is_codspeed_enabled:
+        plugin.instrument.report(session)
+
+
 class BenchmarkFixture:
     """The fixture that can be used to benchmark a function."""
 
@@ -254,11 +223,9 @@ def __init__(self, request: pytest.FixtureRequest):
     def __call__(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
         config = self._request.config
         plugin = get_plugin(config)
-        plugin.benchmark_count += 1
-        if plugin.is_codspeed_enabled and plugin.should_measure:
-            assert plugin.lib is not None
-            return _run_with_instrumentation(
-                plugin.lib, self._request.node.nodeid, config, func, *args, **kwargs
+        if plugin.is_codspeed_enabled:
+            return _measure(
+                plugin, self._request.node.nodeid, config, func, *args, **kwargs
             )
         else:
             return func(*args, **kwargs)
@@ -277,15 +244,3 @@ def benchmark(codspeed_benchmark, request: pytest.FixtureRequest):
     Compatibility with pytest-benchmark
     """
     return codspeed_benchmark
-
-
-@pytest.hookimpl()
-def pytest_sessionfinish(session: pytest.Session, exitstatus):
-    plugin = get_plugin(session.config)
-    if plugin.is_codspeed_enabled:
-        reporter = session.config.pluginmanager.get_plugin("terminalreporter")
-        count_suffix = "benchmarked" if plugin.should_measure else "benchmark tested"
-        reporter.write_sep(
-            "=",
-            f"{plugin.benchmark_count} {count_suffix}",
-        )
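
With the plugin rewired, both entry points, the runtest wrapper for marked tests and the benchmark fixture, now funnel into _measure(), which resolves the uri and name and delegates to the active instrument. A typical test file exercising both paths (file contents are illustrative):

# Illustrative test file: both forms are routed through _measure() when
# pytest runs with --codspeed.
import pytest


@pytest.mark.benchmark
def test_sort_reversed():
    # Marked tests are wrapped via wrap_runtest()
    sorted(range(1_000, 0, -1))


def test_sum(benchmark):
    # The fixture path calls _measure() and returns the function's result
    result = benchmark(sum, range(100))
    assert result == 4950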
12 changes: 7 additions & 5 deletions src/pytest_codspeed/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from pathlib import Path
 
 
@@ -15,11 +17,11 @@ def get_git_relative_path(abs_path: Path) -> Path:
     return abs_path
 
 
-def get_git_relative_uri(uri: str, pytest_rootdir: Path) -> str:
-    """Get the benchmark uri relative to the git root dir.
+def get_git_relative_uri_and_name(nodeid: str, pytest_rootdir: Path) -> tuple[str, str]:
+    """Get the benchmark uri relative to the git root dir and the benchmark name.
 
     Args:
-        uri (str): the benchmark uri, for example:
+        nodeid (str): the pytest nodeid, for example:
             testing/test_excinfo.py::TestFormattedExcinfo::test_repr_source
         pytest_rootdir (str): the pytest root dir, for example:
             /home/user/gitrepo/folder
@@ -29,7 +31,7 @@ def get_git_relative_uri(uri: str, pytest_rootdir: Path) -> str:
             folder/testing/test_excinfo.py::TestFormattedExcinfo::test_repr_source
     """
-    file_path, function_path = uri.split("::", 1)
+    file_path, bench_name = nodeid.split("::", 1)
     absolute_file_path = pytest_rootdir / Path(file_path)
     relative_git_path = get_git_relative_path(absolute_file_path)
-    return f"{str(relative_git_path)}::{function_path}"
+    return (f"{str(relative_git_path)}::{bench_name}", bench_name)
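
The nodeid is split on the first "::" only, so the returned name keeps the class prefix. Reusing the docstring's example, and assuming /home/user/gitrepo is the actual git root on disk (get_git_relative_path returns the path unchanged otherwise):

from pathlib import Path

from pytest_codspeed.utils import get_git_relative_uri_and_name

uri, name = get_git_relative_uri_and_name(
    "testing/test_excinfo.py::TestFormattedExcinfo::test_repr_source",
    Path("/home/user/gitrepo/folder"),
)
# uri  == "folder/testing/test_excinfo.py::TestFormattedExcinfo::test_repr_source"
# name == "TestFormattedExcinfo::test_repr_source"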