From 0f156ae065fabaf38136e465e102680e4aaf36eb Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 1 Feb 2022 14:47:21 -0800 Subject: [PATCH] torchx/runner: log events to torch.monitor Summary: This logs the `torchx.runner.events.Events` to `torch.monitor` as well as the existing event handlers. Once monitor is stable the existing ones will be removed entirely in favor of the new interface. `torch.monitor` is only available with pytorch 1.11 (or main) so it's a no-op if it's not available. Differential Revision: D33928333 fbshipit-source-id: 632a7e096f46e5c1946932d4a765be179e1e53a8 --- .pyre_configuration | 1 + torchx/runner/events/__init__.py | 9 +++++ torchx/runner/events/api.py | 15 +++++++- torchx/runner/events/test/lib_test.py | 54 +++++++++++++++++++++++++++ typestubs/torch/monitor.pyi | 41 ++++++++++++++++++++ 5 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 typestubs/torch/monitor.pyi diff --git a/.pyre_configuration b/.pyre_configuration index c22415643..d6804849a 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -6,6 +6,7 @@ ".*/IPython/core/tests/nonascii.*" ], "source_directories": [ + "typestubs", "." ], "strict": true, diff --git a/torchx/runner/events/__init__.py b/torchx/runner/events/__init__.py index fbad9d057..0db2aa2ad 100644 --- a/torchx/runner/events/__init__.py +++ b/torchx/runner/events/__init__.py @@ -59,6 +59,15 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger: def record(event: TorchxEvent, destination: str = "null") -> None: _get_or_create_logger(destination).info(event.serialize()) + if destination != "console": + # if using torch>1.11 log the event to torch.monitor + try: + from torch import monitor + + monitor.log_event(event.to_monitor_event()) + except ImportError: + pass + class log_event: """ diff --git a/torchx/runner/events/api.py b/torchx/runner/events/api.py index 81a18b170..07317ddaa 100644 --- a/torchx/runner/events/api.py +++ b/torchx/runner/events/api.py @@ -7,8 +7,12 @@ import json from dataclasses import asdict, dataclass +from datetime import datetime from enum import Enum -from typing import Optional, Union +from typing import Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from torch import monitor class SourceType(str, Enum): @@ -60,3 +64,12 @@ def deserialize(data: Union[str, "TorchxEvent"]) -> "TorchxEvent": def serialize(self) -> str: return json.dumps(asdict(self)) + + def to_monitor_event(self) -> "monitor.Event": + from torch import monitor + + return monitor.Event( + name="torch.runner.Event", + timestamp=datetime.now(), + data={k: v for k, v in self.__dict__.items() if v is not None}, + ) diff --git a/torchx/runner/events/test/lib_test.py b/torchx/runner/events/test/lib_test.py index be67ad10a..b912a90c7 100644 --- a/torchx/runner/events/test/lib_test.py +++ b/torchx/runner/events/test/lib_test.py @@ -8,6 +8,7 @@ import json import logging import unittest +from typing import List from unittest.mock import patch, MagicMock from torchx.runner.events import ( @@ -15,8 +16,16 @@ SourceType, TorchxEvent, log_event, + record, ) +try: + from torch import monitor + + SKIP_MONITOR: bool = False +except ImportError: + SKIP_MONITOR: bool = True + class TorchxEventLibTest(unittest.TestCase): def assert_event( @@ -57,6 +66,51 @@ def test_event_deser(self) -> None: deser_event = TorchxEvent.deserialize(json_event) self.assert_event(event, deser_event) + @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") + def test_monitor(self) -> None: + event = TorchxEvent( + session="test_session", + scheduler="test_scheduler", + api="test_api", + source=SourceType.EXTERNAL, + ) + monitor_event = event.to_monitor_event() + self.assertEqual( + monitor_event.data, + { + "session": "test_session", + "scheduler": "test_scheduler", + "api": "test_api", + "source": "EXTERNAL", + }, + ) + self.assertEqual(monitor_event.name, "torch.runner.Event") + + @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") + @patch("torchx.runner.events._get_or_create_logger") + def test_monitor_record(self, get_logging_handler: MagicMock) -> None: + event = TorchxEvent( + session="test_session", + scheduler="test_scheduler", + api="test_api", + source=SourceType.EXTERNAL, + ) + events: List[monitor.Event] = [] + + def handler(e: monitor.Event) -> None: + events.append(e) + + handle = monitor.register_event_handler(handler) + + try: + record(event) + finally: + monitor.unregister_event_handler(handle) + + self.assertEqual(get_logging_handler.call_count, 1) + self.assertEqual(len(events), 1) + self.assertEqual(events[0].data["session"], "test_session") + @patch("torchx.runner.events.record") class LogEventTest(unittest.TestCase): diff --git a/typestubs/torch/monitor.pyi b/typestubs/torch/monitor.pyi new file mode 100644 index 000000000..7d6eff8b9 --- /dev/null +++ b/typestubs/torch/monitor.pyi @@ -0,0 +1,41 @@ +# Defined in torch/csrc/monitor/python_init.cpp + +from typing import List, Dict, Callable, Union +from enum import Enum +import datetime + +class Aggregation(Enum): + VALUE = "value" + MEAN = "mean" + COUNT = "count" + SUM = "sum" + MAX = "max" + MIN = "min" + +class Stat: + name: str + count: int + def __init__( + self, name: str, aggregations: List[Aggregation], window_size: int, + max_samples: int = -1, + ) -> None: ... + def add(self, v: float) -> None: ... + def get(self) -> Dict[Aggregation, float]: ... + +class Event: + name: str + timestamp: datetime.datetime + data: Dict[str, Union[int, float, bool, str]] + def __init__( + self, + name: str, + timestamp: datetime.datetime, + data: Dict[str, Union[int, float, bool, str]], + ) -> None: ... + +def log_event(e: Event) -> None: ... + +class EventHandlerHandle: ... + +def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ... +def unregister_event_handler(handle: EventHandlerHandle) -> None: ...