
Commit 51d370f: [doc] Move each profiler to its own file + Add missing PyTorchProfiler to the doc (#7822)

Parent: 6a0d503

File tree: 9 files changed, +468 -394 lines

CHANGELOG.md (3 additions, 0 deletions)

@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))
 
 
+- Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))
+
+
 ### Deprecated
 
 

docs/source/api_references.rst (8 additions, 1 deletion)

@@ -137,8 +137,15 @@ Profiler API
 .. autosummary::
     :toctree: api
     :nosignatures:
+    :template: classtemplate.rst
+
+    AbstractProfiler
+    AdvancedProfiler
+    BaseProfiler
+    PassThroughProfiler
+    PyTorchProfiler
+    SimpleProfiler
 
-    profilers
 
 Trainer API
 -----------

pytorch_lightning/profiler/__init__.py (6 additions, 4 deletions)

@@ -194,14 +194,16 @@ def custom_processing_step(self, data):
     python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'
 
 """
-
-from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
+from pytorch_lightning.profiler.advanced import AdvancedProfiler
+from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler
 from pytorch_lightning.profiler.pytorch import PyTorchProfiler
+from pytorch_lightning.profiler.simple import SimpleProfiler
 
 __all__ = [
+    'AbstractProfiler',
     'BaseProfiler',
-    'SimpleProfiler',
     'AdvancedProfiler',
     'PassThroughProfiler',
-    "PyTorchProfiler",
+    'PyTorchProfiler',
+    'SimpleProfiler',
 ]
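
Note: because `__init__.py` re-exports every profiler, user code importing from the package root is unaffected by this refactor; only imports that reached into the old `pytorch_lightning.profiler.profilers` module would need to move to the new per-profiler files. A quick sketch of the two equivalent import paths after this commit:

```python
# The package root remains the stable public import surface, while the new
# per-profiler modules introduced by this commit expose the same classes.
from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.advanced import AdvancedProfiler as AdvancedProfilerDirect

# Same class object, reached via either path.
assert AdvancedProfiler is AdvancedProfilerDirect
```
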
pytorch_lightning/profiler/advanced.py (new file: 92 additions, 0 deletions)

@@ -0,0 +1,92 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Profiler to check if there are any bottlenecks in your code."""
+import cProfile
+import io
+import logging
+import pstats
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from pytorch_lightning.profiler.base import BaseProfiler
+
+log = logging.getLogger(__name__)
+
+
+class AdvancedProfiler(BaseProfiler):
+    """
+    This profiler uses Python's cProfiler to record more detailed information about
+    time spent in each function call recorded during a given action. The output is quite
+    verbose and you should only use this if you want very detailed reports.
+    """
+
+    def __init__(
+        self,
+        dirpath: Optional[Union[str, Path]] = None,
+        filename: Optional[str] = None,
+        line_count_restriction: float = 1.0,
+        output_filename: Optional[str] = None,
+    ) -> None:
+        """
+        Args:
+            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
+                ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
+                will be used.
+
+            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
+                The ``.txt`` extension will be used automatically.
+
+            line_count_restriction: this can be used to limit the number of functions
+                reported for each action. either an integer (to select a count of lines),
+                or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
+
+        Raises:
+            ValueError:
+                If you attempt to stop recording an action which was never started.
+        """
+        super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename)
+        self.profiled_actions: Dict[str, cProfile.Profile] = {}
+        self.line_count_restriction = line_count_restriction
+
+    def start(self, action_name: str) -> None:
+        if action_name not in self.profiled_actions:
+            self.profiled_actions[action_name] = cProfile.Profile()
+        self.profiled_actions[action_name].enable()
+
+    def stop(self, action_name: str) -> None:
+        pr = self.profiled_actions.get(action_name)
+        if pr is None:
+            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
+        pr.disable()
+
+    def summary(self) -> str:
+        recorded_stats = {}
+        for action_name, pr in self.profiled_actions.items():
+            s = io.StringIO()
+            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
+            ps.print_stats(self.line_count_restriction)
+            recorded_stats[action_name] = s.getvalue()
+        return self._stats_to_str(recorded_stats)
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        super().teardown(stage=stage)
+        self.profiled_actions = {}
+
+    def __reduce__(self):
+        # avoids `TypeError: cannot pickle 'cProfile.Profile' object`
+        return (
+            self.__class__,
+            tuple(),
+            dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction),
+        )
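
Note: a minimal standalone usage sketch for the relocated profiler, using only the API shown in this commit; the `dirpath` and `filename` values are illustrative:

```python
# A hedged sketch: standalone use of AdvancedProfiler outside the Trainer.
# `dirpath="."` and `filename="perf_report"` are arbitrary example values.
from pytorch_lightning.profiler import AdvancedProfiler

profiler = AdvancedProfiler(dirpath=".", filename="perf_report")

# `profile` is the context manager inherited from BaseProfiler:
# it calls `start` on entry and `stop` on exit.
with profiler.profile("squares"):
    total = sum(i * i for i in range(1_000_000))

# Writes the cProfile stats for each recorded action to ./perf_report.txt;
# with no `filename`, the report would go to the rank-zero logger instead.
profiler.describe()
```

In a training run the profiler is typically passed to the Trainer instead, e.g. `Trainer(profiler=AdvancedProfiler(dirpath=".", filename="perf_report"))`.
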

pytorch_lightning/profiler/base.py (new file: 215 additions, 0 deletions)

@@ -0,0 +1,215 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Profiler to check if there are any bottlenecks in your code."""
+import logging
+import os
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, TextIO, Union
+
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import get_filesystem
+
+log = logging.getLogger(__name__)
+
+
+class AbstractProfiler(ABC):
+    """Specification of a profiler."""
+
+    @abstractmethod
+    def start(self, action_name: str) -> None:
+        """Defines how to start recording an action."""
+
+    @abstractmethod
+    def stop(self, action_name: str) -> None:
+        """Defines how to record the duration once an action is complete."""
+
+    @abstractmethod
+    def summary(self) -> str:
+        """Create profiler summary in text format."""
+
+    @abstractmethod
+    def setup(self, **kwargs: Any) -> None:
+        """Execute arbitrary pre-profiling set-up steps as defined by subclass."""
+
+    @abstractmethod
+    def teardown(self, **kwargs: Any) -> None:
+        """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
+
+
+class BaseProfiler(AbstractProfiler):
+    """
+    If you wish to write a custom profiler, you should inherit from this class.
+    """
+
+    def __init__(
+        self,
+        dirpath: Optional[Union[str, Path]] = None,
+        filename: Optional[str] = None,
+        output_filename: Optional[str] = None,
+    ) -> None:
+        self.dirpath = dirpath
+        self.filename = filename
+        if output_filename is not None:
+            rank_zero_warn(
+                "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in"
+                " favor of `dirpath` and `filename`. Support for the old signature will be removed in v1.5",
+                DeprecationWarning
+            )
+            filepath = Path(output_filename)
+            self.dirpath = filepath.parent
+            self.filename = filepath.stem
+
+        self._output_file: Optional[TextIO] = None
+        self._write_stream: Optional[Callable] = None
+        self._local_rank: Optional[int] = None
+        self._log_dir: Optional[str] = None
+        self._stage: Optional[str] = None
+
+    @contextmanager
+    def profile(self, action_name: str) -> None:
+        """
+        Yields a context manager to encapsulate the scope of a profiled action.
+
+        Example::
+
+            with self.profile('load training data'):
+                # load training data code
+
+        The profiler will start once you've entered the context and will automatically
+        stop once you exit the code block.
+        """
+        try:
+            self.start(action_name)
+            yield action_name
+        finally:
+            self.stop(action_name)
+
+    def profile_iterable(self, iterable, action_name: str) -> None:
+        iterator = iter(iterable)
+        while True:
+            try:
+                self.start(action_name)
+                value = next(iterator)
+                self.stop(action_name)
+                yield value
+            except StopIteration:
+                self.stop(action_name)
+                break
+
+    def _rank_zero_info(self, *args, **kwargs) -> None:
+        if self._local_rank in (None, 0):
+            log.info(*args, **kwargs)
+
+    def _prepare_filename(self, extension: str = ".txt") -> str:
+        filename = ""
+        if self._stage is not None:
+            filename += f"{self._stage}-"
+        filename += str(self.filename)
+        if self._local_rank is not None:
+            filename += f"-{self._local_rank}"
+        filename += extension
+        return filename
+
+    def _prepare_streams(self) -> None:
+        if self._write_stream is not None:
+            return
+        if self.filename:
+            filepath = os.path.join(self.dirpath, self._prepare_filename())
+            fs = get_filesystem(filepath)
+            file = fs.open(filepath, "a")
+            self._output_file = file
+            self._write_stream = file.write
+        else:
+            self._write_stream = self._rank_zero_info
+
+    def describe(self) -> None:
+        """Logs a profile report after the conclusion of run."""
+        # there are pickling issues with open file handles in Python 3.6
+        # so to avoid them, we open and close the files within this function
+        # by calling `_prepare_streams` and `teardown`
+        self._prepare_streams()
+        summary = self.summary()
+        if summary:
+            self._write_stream(summary)
+        if self._output_file is not None:
+            self._output_file.flush()
+        self.teardown(stage=self._stage)
+
+    def _stats_to_str(self, stats: Dict[str, str]) -> str:
+        stage = f"{self._stage.upper()} " if self._stage is not None else ""
+        output = [stage + "Profiler Report"]
+        for action, value in stats.items():
+            header = f"Profile stats for: {action}"
+            if self._local_rank is not None:
+                header += f" rank: {self._local_rank}"
+            output.append(header)
+            output.append(value)
+        return os.linesep.join(output)
+
+    def setup(
+        self,
+        stage: Optional[str] = None,
+        local_rank: Optional[int] = None,
+        log_dir: Optional[str] = None,
+    ) -> None:
+        """Execute arbitrary pre-profiling set-up steps."""
+        self._stage = stage
+        self._local_rank = local_rank
+        self._log_dir = log_dir
+        self.dirpath = self.dirpath or log_dir
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """
+        Execute arbitrary post-profiling tear-down steps.
+
+        Closes the currently open file and stream.
+        """
+        self._write_stream = None
+        if self._output_file is not None:
+            self._output_file.close()
+            self._output_file = None  # can't pickle TextIOWrapper
+
+    def __del__(self) -> None:
+        self.teardown(stage=self._stage)
+
+    def start(self, action_name: str) -> None:
+        raise NotImplementedError
+
+    def stop(self, action_name: str) -> None:
+        raise NotImplementedError
+
+    def summary(self) -> str:
+        raise NotImplementedError
+
+    @property
+    def local_rank(self) -> int:
+        return 0 if self._local_rank is None else self._local_rank
+
+
+class PassThroughProfiler(BaseProfiler):
+    """
+    This class should be used when you don't want the (small) overhead of profiling.
+    The Trainer uses this class by default.
+    """
+
+    def start(self, action_name: str) -> None:
+        pass
+
+    def stop(self, action_name: str) -> None:
+        pass
+
+    def summary(self) -> str:
+        return ""
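
Note: `BaseProfiler`'s docstring invites custom subclasses; implementing `start`, `stop`, and `summary` is enough, since the base class supplies the context manager, stream handling, and report formatting. A sketch under that contract follows; `WallClockProfiler` and its timing scheme are hypothetical, not part of this commit:

```python
# A hedged sketch of a custom profiler built on BaseProfiler.
# `WallClockProfiler` is a hypothetical example class.
import time
from typing import Dict

from pytorch_lightning.profiler import BaseProfiler


class WallClockProfiler(BaseProfiler):
    """Accumulates total wall-clock time per action name."""

    def __init__(self, dirpath=None, filename=None) -> None:
        super().__init__(dirpath=dirpath, filename=filename)
        self._starts: Dict[str, float] = {}
        self._totals: Dict[str, float] = {}

    def start(self, action_name: str) -> None:
        self._starts[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        begin = self._starts.pop(action_name, None)
        if begin is None:
            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
        self._totals[action_name] = self._totals.get(action_name, 0.0) + time.monotonic() - begin

    def summary(self) -> str:
        # `_stats_to_str` (from BaseProfiler) adds the report header and the
        # per-action "Profile stats for:" lines around these values.
        return self._stats_to_str({name: f"{total:.3f}s" for name, total in self._totals.items()})


profiler = WallClockProfiler()
# `profile_iterable` times each `next` call on the wrapped iterator.
for batch in profiler.profile_iterable(range(3), "fetch_batch"):
    pass
profiler.describe()  # no filename set, so the report goes to the rank-zero logger
```
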
