-
Notifications
You must be signed in to change notification settings - Fork 54
/
base.py
70 lines (54 loc) · 2.2 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from dataclasses import dataclass
from logging import getLogger
from typing import Type
from hydra.utils import get_class
from .backends.base import Backend
from .backends.config import BackendConfig
from .config import BenchmarkConfig
from .hub_utils import PushToHubMixin, classproperty
from .launchers import LauncherConfig
from .launchers.base import Launcher
from .report import BenchmarkReport
from .scenarios import ScenarioConfig
from .scenarios.base import Scenario
# Module logger registered under the fixed name "benchmark" rather than
# __name__; presumably so several modules share one logger — confirm.
# NOTE(review): not referenced anywhere in this module's visible code.
LOGGER = getLogger(name="benchmark")
@dataclass
class Benchmark(PushToHubMixin):
    """A benchmark configuration paired with its report.

    The class also hosts the two execution entry points: `launch`, which
    dispatches through the configured launcher, and `run`, the worker that
    executes the configured scenario on the configured backend.
    """

    # Full benchmark configuration. A plain dict is also accepted and is
    # converted in __post_init__.
    config: BenchmarkConfig
    # Report attached to this benchmark. Unlike `config`, it is stored exactly
    # as given; no conversion or type check is performed here.
    report: BenchmarkReport

    def __post_init__(self):
        """Coerce `config` into a BenchmarkConfig instance.

        A dict is expanded as keyword arguments into BenchmarkConfig (a
        shallow expansion: nested values are passed through unchanged). An
        existing BenchmarkConfig is kept as-is.

        Raises:
            ValueError: if `config` is neither a dict nor a BenchmarkConfig.
        """
        raw_config = self.config

        if isinstance(raw_config, dict):
            self.config = BenchmarkConfig(**raw_config)
            return

        if isinstance(raw_config, BenchmarkConfig):
            return

        raise ValueError("config must be either a dict or a BenchmarkConfig instance")

    @classmethod
    def launch(cls, config: BenchmarkConfig):
        """Run a benchmark through the launcher described by `config.launcher`.

        The launcher class is resolved from the dotted path stored in
        `config.launcher._target_` and instantiated with that launcher
        config. The launcher then receives the `run` classmethod as its worker
        and `[config]` as the worker's arguments. How and where the worker is
        invoked is decided by the launcher.

        Returns:
            Whatever the launcher's `launch` returns (presumably the
            BenchmarkReport produced by `run`; this is not enforced here).
        """
        launcher_cfg: LauncherConfig = config.launcher
        launcher_cls: Type[Launcher] = get_class(launcher_cfg._target_)
        # Hand over the callable itself, not its result, so the launcher
        # controls the execution of the worker.
        return launcher_cls(launcher_cfg).launch(worker=cls.run, worker_args=[config])

    @classmethod
    def run(cls, config: BenchmarkConfig):
        """Execute the configured scenario on the configured backend.

        Both classes are resolved from the `_target_` dotted paths of
        `config.backend` and `config.scenario` and instantiated with their
        respective sub-configs. The backend is constructed before the
        scenario.

        Returns:
            The value returned by `scenario.run(backend)`.
        """
        backend_cfg: BackendConfig = config.backend
        backend: Backend = get_class(backend_cfg._target_)(backend_cfg)

        scenario_cfg: ScenarioConfig = config.scenario
        scenario: Scenario = get_class(scenario_cfg._target_)(scenario_cfg)

        return scenario.run(backend)

    @classproperty
    def default_filename(cls) -> str:
        """Default file name for a serialized benchmark.

        Presumably consumed by PushToHubMixin's save/push helpers; confirm
        against the mixin.
        """
        return "benchmark.json"