Resolve schedule step bug for PyTorch Profiler (#6674)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
tchaton and carmocca authored Mar 25, 2021
1 parent 217c12a commit 0ea8f39
Showing 3 changed files with 30 additions and 27 deletions.
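
For orientation, a minimal sketch of how the profiler touched by this commit is typically wired into a Trainer. This is illustrative only: the output directory and the commented-out model/datamodule are placeholders, and the explicit `schedule` simply mirrors the new default (wait=1, warmup=1, active=3) introduced below; it assumes a kineto-enabled torch (>= 1.8.1) and a Lightning release of this era.

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import PyTorchProfiler

    # Mirrors the new default schedule: skip 1 step, warm up for 1, record 3.
    profiler = PyTorchProfiler(
        dirpath="profiler_logs",   # hypothetical output directory
        export_to_chrome=True,     # write traces viewable in TensorBoard/Chrome
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    )

    trainer = Trainer(profiler=profiler, max_epochs=1)
    # trainer.fit(model, datamodule=datamodule)  # placeholders for a real model/datamodule
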
54 changes: 28 additions & 26 deletions pytorch_lightning/profiler/pytorch.py
@@ -107,6 +107,15 @@ def __init__(self, schedule: Callable) -> None:
         if not _KINETO_AVAILABLE:
             raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.")
         self._schedule = schedule
+        self.reset()
+
+    def setup(self, start_action_name: str) -> None:
+        self._start_action_name = start_action_name
+
+    def pre_step(self, current_action: str) -> None:
+        self._current_action = current_action
+
+    def reset(self):
         self._num_training_step_and_backward = 0
         self._num_validation_step = 0
         self._num_test_step = 0
@@ -119,12 +128,6 @@ def __init__(self, schedule: Callable) -> None:
         self._current_action: Optional[str] = None
         self._start_action_name: Optional[str] = None

-    def setup(self, start_action_name: str) -> None:
-        self._start_action_name = start_action_name
-
-    def pre_step(self, current_action: str) -> None:
-        self._current_action = current_action
-
     @property
     def num_step(self) -> int:
         if self._current_action == "training_step_and_backward":
@@ -142,8 +145,9 @@ def _step(self) -> None:
         if self._current_action == "training_step_and_backward":
             self._num_training_step_and_backward += 1
         elif self._current_action == "validation_step":
-            if self._start_action_name == "on_train_start" and self._num_training_step_and_backward > 0:
-                self._num_validation_step += 1
+            if self._start_action_name == "on_fit_start":
+                if self._num_training_step_and_backward > 0:
+                    self._num_validation_step += 1
             else:
                 self._num_validation_step += 1
         elif self._current_action == "test_step":
@@ -210,7 +214,7 @@ class PyTorchProfiler(BaseProfiler):
         "count",
     }
     START_RECORD_FUNCTIONS = {
-        'on_train_start',
+        'on_fit_start',
         'on_validation_start',
         'on_test_start',
         'on_predict_start',
@@ -289,8 +293,9 @@ def __init__(
         self._export_to_chrome = export_to_chrome
         self._row_limit = row_limit
         self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
-        self._record_functions_start = record_functions | self.START_RECORD_FUNCTIONS
-        self._record_functions = record_functions | self.RECORD_FUNCTIONS
+        self._user_record_functions = record_functions
+        self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS
+        self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS
         self._record_module_names = record_module_names
         self._profiler_kwargs = profiler_kwargs

@@ -304,14 +309,14 @@ def __init__(
         self._schedule: Optional[ScheduleWrapper] = None

         if _KINETO_AVAILABLE:
-            self.__init_kineto__(profiler_kwargs)
+            self._init_kineto(profiler_kwargs)

         if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
             raise MisconfigurationException(
                 f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
             )

-    def __init_kineto__(self, profiler_kwargs: Any):
+    def _init_kineto(self, profiler_kwargs: Any) -> None:
         has_schedule = "schedule" in profiler_kwargs
         self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs

@@ -362,7 +367,7 @@ def __deprecation_check(
     def _default_schedule() -> Optional[callable]:
         if _KINETO_AVAILABLE:
             # Those schedule defaults allow the profiling overhead to be negligible over training time.
-            return torch.profiler.schedule(wait=1, warmup=1, active=2)
+            return torch.profiler.schedule(wait=1, warmup=1, active=3)

     def _default_activities(self) -> List['ProfilerActivity']:
         activities = []
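
The hunk above bumps the default schedule from two to three active steps. For reference, a small sketch of how `torch.profiler.schedule` interprets these arguments (values chosen only to illustrate the semantics, assuming torch >= 1.8.1):

    import torch
    from torch.profiler import ProfilerAction

    sched = torch.profiler.schedule(wait=1, warmup=1, active=3)
    # step 0 is skipped, step 1 warms up, steps 2-4 are recorded, then the cycle repeats
    actions = [sched(step) for step in range(5)]
    assert actions[0] == ProfilerAction.NONE
    assert actions[1] == ProfilerAction.WARMUP
    assert actions[2] == actions[3] == ProfilerAction.RECORD
    assert actions[4] == ProfilerAction.RECORD_AND_SAVE
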
@@ -374,10 +379,6 @@ def _default_activities(self) -> List['ProfilerActivity']:
             activities.append(ProfilerActivity.CUDA)
         return activities

-    @property
-    def step_action_names(self) -> Set[str]:
-        return self.STEP_FUNCTIONS | self._record_functions
-
     def start(self, action_name: str) -> None:
         if self.profiler is None and action_name in self._record_functions_start:

@@ -411,9 +412,6 @@ def start(self, action_name: str) -> None:
             recording.__enter__()
             self._recording_map[action_name] = recording

-        if self._schedule is not None:
-            self._schedule.pre_step(action_name)
-
     def stop(self, action_name: str) -> None:
         if action_name in self._recording_map:
             self._recording_map[action_name].__exit__(None, None, None)
@@ -422,16 +420,14 @@ def stop(self, action_name: str) -> None:
         if not _KINETO_AVAILABLE or self._emit_nvtx:
             return

-        if action_name in self.step_action_names:
+        if self.profiler is not None and action_name in self.STEP_FUNCTIONS:
             if self._schedule is not None:
-                self._schedule._current_action = action_name
+                self._schedule.pre_step(action_name)

             def on_trace_ready(profiler):
-                filename = f"{action_name}_{self.local_rank}"
-
                 if self.dirpath is not None:
                     if self._export_to_chrome:
-                        handler = tensorboard_trace_handler(self.dirpath, filename)
+                        handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension=""))
                         handler(profiler)

                     if self._export_to_flame_graph:
@@ -442,6 +438,9 @@ def on_trace_ready(profiler):

             if not self._has_on_trace_ready:
                 self.profiler.on_trace_ready = on_trace_ready
+
+            if self._schedule is not None:
+                self.profiler.step_num = self._schedule.num_step
             self.profiler.step()

     def summary(self) -> str:
@@ -492,6 +491,9 @@ def _delete_profilers(self) -> None:
             self._cache_functions_events()
             self.profiler = None

+        if self._schedule is not None:
+            self._schedule.reset()
+
         if self._parent_profiler is not None:
             self._parent_profiler.__exit__(None, None, None)
             self._parent_profiler = None
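
Taken together, the changes to this file mean `ScheduleWrapper` keeps one step counter per action, `stop()` copies that count into `profiler.step_num` before stepping, and the counters are cleared when the profilers are deleted. A rough, hypothetical illustration of the counting behaviour, driving the wrapper by hand (normally the profiler calls these hooks internally; requires a kineto-enabled torch >= 1.8.1):

    import torch
    from pytorch_lightning.profiler.pytorch import ScheduleWrapper

    wrapper = ScheduleWrapper(torch.profiler.schedule(wait=1, warmup=1, active=3))
    wrapper.setup(start_action_name="on_fit_start")

    wrapper.pre_step("training_step_and_backward")
    wrapper._step()                  # bumps the training-step counter
    assert wrapper.num_step == 1

    wrapper.pre_step("validation_step")
    wrapper._step()                  # counted, since a training step already ran
    assert wrapper.num_step == 1

    wrapper.reset()                  # cleared again by `_delete_profilers`
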
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -57,5 +57,5 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
     def setup(self) -> None:
         trainer = self.trainer
         local_rank = trainer.local_rank if trainer.world_size > 1 else None
-        trainer.profiler.lightning_module = proxy(trainer.lightning_module)
+        trainer.profiler._lightning_module = proxy(trainer.lightning_module)
         trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
1 change: 1 addition & 0 deletions requirements/adjust_versions.py
@@ -11,6 +11,7 @@
     "1.7.0": dict(torchvision="0.8.1", torchtext="0.8"),
     "1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"),
     "1.8.0": dict(torchvision="0.9.0", torchtext="0.9"),
+    "1.8.1": dict(torchvision="0.9.0", torchtext="0.9"),
 }
