-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: create an abstraction for each instrument
- Loading branch information
Showing
6 changed files
with
181 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABCMeta, abstractmethod | ||
from enum import Enum | ||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from typing import Callable, ParamSpec, TypeVar | ||
|
||
import pytest | ||
|
||
T = TypeVar("T") | ||
P = ParamSpec("P") | ||
|
||
|
||
class Instrument(metaclass=ABCMeta): | ||
@abstractmethod | ||
def __init__(self): ... | ||
|
||
@abstractmethod | ||
def get_instrument_config_str_and_warns(self) -> tuple[str, list[str]]: ... | ||
|
||
@abstractmethod | ||
def measure( | ||
self, | ||
name: str, | ||
uri: str, | ||
fn: Callable[P, T], | ||
*args: P.args, | ||
**kwargs: P.kwargs, | ||
) -> T: ... | ||
|
||
@abstractmethod | ||
def report(self, session: pytest.Session) -> None: ... | ||
|
||
|
||
class CodSpeedMeasurementMode(str, Enum): | ||
Instrumentation = "instrumentation" | ||
|
||
|
||
def get_instrument_from_mode(mode: CodSpeedMeasurementMode) -> type[Instrument]: | ||
from pytest_codspeed.instruments.instrumentation import InstrumentationInstrument | ||
|
||
return InstrumentationInstrument |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
import sys | ||
from typing import TYPE_CHECKING | ||
|
||
from pytest_codspeed import __version__ | ||
from pytest_codspeed._wrapper import get_lib | ||
from pytest_codspeed.instruments import Instrument | ||
|
||
if TYPE_CHECKING: | ||
from typing import Callable | ||
|
||
from pytest import Session | ||
|
||
from pytest_codspeed.instruments import P, T | ||
|
||
SUPPORTS_PERF_TRAMPOLINE = sys.version_info >= (3, 12) | ||
|
||
|
||
class InstrumentationInstrument(Instrument): | ||
def __init__(self): | ||
self.benchmark_count = 0 | ||
self.should_measure = os.environ.get("CODSPEED_ENV") is not None | ||
if self.should_measure: | ||
self.lib = get_lib() | ||
self.lib.dump_stats_at( | ||
f"Metadata: pytest-codspeed {__version__}".encode("ascii") | ||
) | ||
if SUPPORTS_PERF_TRAMPOLINE: | ||
sys.activate_stack_trampoline("perf") | ||
else: | ||
self.lib = None | ||
|
||
def get_instrument_config_str_and_warns(self) -> tuple[str, list[str]]: | ||
config = ( | ||
f"mode: instrumentation, " | ||
f"callgraph: {'enabled' if SUPPORTS_PERF_TRAMPOLINE else 'not supported'}" | ||
) | ||
warnings = [] | ||
if not self.should_measure: | ||
warnings.append( | ||
"\033[1m" | ||
"NOTICE: codspeed is enabled, but no performance measurement" | ||
" will be made since it's running in an unknown environment." | ||
"\033[0m" | ||
) | ||
return config, warnings | ||
|
||
def measure( | ||
self, | ||
name: str, | ||
uri: str, | ||
fn: Callable[P, T], | ||
*args: P.args, | ||
**kwargs: P.kwargs, | ||
) -> T: | ||
self.benchmark_count += 1 | ||
if self.lib is None: # Thus should_measure is False | ||
return fn(*args, **kwargs) | ||
|
||
def __codspeed_root_frame__() -> T: | ||
return fn(*args, **kwargs) | ||
|
||
if SUPPORTS_PERF_TRAMPOLINE: | ||
# Warmup CPython performance map cache | ||
__codspeed_root_frame__() | ||
|
||
self.lib.zero_stats() | ||
self.lib.start_instrumentation() | ||
try: | ||
return __codspeed_root_frame__() | ||
finally: | ||
# Ensure instrumentation is stopped even if the test failed | ||
self.lib.stop_instrumentation() | ||
self.lib.dump_stats_at(uri.encode("ascii")) | ||
|
||
def report(self, session: Session) -> None: | ||
reporter = session.config.pluginmanager.get_plugin("terminalreporter") | ||
count_suffix = "benchmarked" if self.should_measure else "benchmark tested" | ||
reporter.write_sep( | ||
"=", | ||
f"{self.benchmark_count} {count_suffix}", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.