Skip to content

Commit

Permalink
Merge pull request #1556 from nijkah/support_old_torch
Browse files Browse the repository at this point in the history
Add an exception for torch<1.8.1
  • Loading branch information
msaroufim authored Apr 16, 2022
2 parents 0244694 + 5eda5d7 commit 6c13b62
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions ts/torch_handler/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@
import importlib.util
import time
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from pkg_resources import packaging
from ..utils.util import list_classes_from_module, load_label_mapping

if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.8.1"):
from torch.profiler import profile, record_function, ProfilerActivity
PROFILER_AVAILABLE = True
else:
PROFILER_AVAILABLE = False


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -211,7 +217,11 @@ def handle(self, data, context):

is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None)
if is_profiler_enabled:
output, _ = self._infer_with_profiler(data=data)
if PROFILER_AVAILABLE:
output, _ = self._infer_with_profiler(data=data)
else:
raise RuntimeError("Profiler is enabled but current version of torch does not support."
"Install torch>=1.8.1 to use profiler.")
else:
if self._is_describe():
output = [self.describe_handle()]
Expand Down

0 comments on commit 6c13b62

Please sign in to comment.