Skip to content

Commit

Permalink
Refactor PyTorch profiler 4/5 (#6349)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas chaton <thomas@grid.ai>
  • Loading branch information
carmocca and tchaton authored Mar 23, 2021
1 parent 3cf0c31 commit 51b10f7
Show file tree
Hide file tree
Showing 11 changed files with 377 additions and 219 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))


- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))


Expand All @@ -72,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))


- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
Expand All @@ -83,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))


- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),

[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),
Expand Down
12 changes: 5 additions & 7 deletions pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,15 @@ def _prepare_filename(self) -> str:
filename += f"{self._stage}-"
filename += str(self.filename)
if self._local_rank is not None:
filename += f"-{self.local_rank}"
filename += f"-{self._local_rank}"
filename += ".txt"
return filename

def _prepare_streams(self) -> None:
if self._write_stream is not None:
return
if self.filename:
dirpath = self.dirpath or self._log_dir
filepath = os.path.join(dirpath, self._prepare_filename())
filepath = os.path.join(self.dirpath, self._prepare_filename())
fs = get_filesystem(filepath)
file = fs.open(filepath, "a")
self._output_file = file
Expand Down Expand Up @@ -175,8 +174,7 @@ def setup(
self._stage = stage
self._local_rank = local_rank
self._log_dir = log_dir
if self.dirpath is None:
self.dirpath = self._log_dir
self.dirpath = self.dirpath or log_dir

def teardown(self, stage: Optional[str] = None) -> None:
"""
Expand All @@ -202,8 +200,8 @@ def summary(self) -> str:
raise NotImplementedError

@property
def local_rank(self):
return '0' if self._local_rank is None else self._local_rank
def local_rank(self) -> int:
return 0 if self._local_rank is None else self._local_rank


class PassThroughProfiler(BaseProfiler):
Expand Down
Loading

0 comments on commit 51b10f7

Please sign in to comment.