diff --git a/ddtrace/profiling/collector/pytorch.py b/ddtrace/profiling/collector/pytorch.py index 6d9556f7b6..9a8a7ff2de 100644 --- a/ddtrace/profiling/collector/pytorch.py +++ b/ddtrace/profiling/collector/pytorch.py @@ -61,12 +61,12 @@ def _start_service(self): raise collector.CollectorUnavailable(e) self._torch_module = torch self.patch() - super(MLProfilerCollector, self)._start_service() + super()._start_service() def _stop_service(self): # type: (...) -> None """Stop collecting framework profiler usage.""" - super(MLProfilerCollector, self)._stop_service() + super()._stop_service() self.unpatch() def patch(self): @@ -96,6 +96,9 @@ class TorchProfilerCollector(MLProfilerCollector): PROFILED_TORCH_CLASS = _WrappedTorchProfiler + def __init__(self, recorder=None): + super().__init__(recorder) + def _get_patch_target(self): # type: (...) -> typing.Any return self._torch_module.profiler.profile