|  | 
| 8 | 8 | import json | 
| 9 | 9 | import logging | 
| 10 | 10 | import unittest | 
|  | 11 | +from typing import List | 
| 11 | 12 | from unittest.mock import patch, MagicMock | 
| 12 | 13 | 
 | 
| 13 | 14 | from torchx.runner.events import ( | 
| 14 | 15 |     _get_or_create_logger, | 
| 15 | 16 |     SourceType, | 
| 16 | 17 |     TorchxEvent, | 
| 17 | 18 |     log_event, | 
|  | 19 | +    record, | 
| 18 | 20 | ) | 
| 19 | 21 | 
 | 
|  | 22 | +try: | 
|  | 23 | +    from torch import monitor | 
|  | 24 | + | 
|  | 25 | +    SKIP_MONITOR: bool = False | 
|  | 26 | +except ImportError: | 
|  | 27 | +    SKIP_MONITOR: bool = True | 
|  | 28 | + | 
| 20 | 29 | 
 | 
| 21 | 30 | class TorchxEventLibTest(unittest.TestCase): | 
| 22 | 31 |     def assert_event( | 
| @@ -57,6 +66,51 @@ def test_event_deser(self) -> None: | 
| 57 | 66 |         deser_event = TorchxEvent.deserialize(json_event) | 
| 58 | 67 |         self.assert_event(event, deser_event) | 
| 59 | 68 | 
 | 
|  | 69 | +    @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") | 
|  | 70 | +    def test_monitor(self) -> None: | 
|  | 71 | +        event = TorchxEvent( | 
|  | 72 | +            session="test_session", | 
|  | 73 | +            scheduler="test_scheduler", | 
|  | 74 | +            api="test_api", | 
|  | 75 | +            source=SourceType.EXTERNAL, | 
|  | 76 | +        ) | 
|  | 77 | +        monitor_event = event.to_monitor_event() | 
|  | 78 | +        self.assertEqual( | 
|  | 79 | +            monitor_event.data, | 
|  | 80 | +            { | 
|  | 81 | +                "session": "test_session", | 
|  | 82 | +                "scheduler": "test_scheduler", | 
|  | 83 | +                "api": "test_api", | 
|  | 84 | +                "source": "EXTERNAL", | 
|  | 85 | +            }, | 
|  | 86 | +        ) | 
|  | 87 | +        self.assertEqual(monitor_event.name, "torch.runner.Event") | 
|  | 88 | + | 
|  | 89 | +    @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") | 
|  | 90 | +    @patch("torchx.runner.events._get_or_create_logger") | 
|  | 91 | +    def test_monitor_record(self, get_logging_handler: MagicMock) -> None: | 
|  | 92 | +        event = TorchxEvent( | 
|  | 93 | +            session="test_session", | 
|  | 94 | +            scheduler="test_scheduler", | 
|  | 95 | +            api="test_api", | 
|  | 96 | +            source=SourceType.EXTERNAL, | 
|  | 97 | +        ) | 
|  | 98 | +        events: List[monitor.Event] = [] | 
|  | 99 | + | 
|  | 100 | +        def handler(e: monitor.Event) -> None: | 
|  | 101 | +            events.append(e) | 
|  | 102 | + | 
|  | 103 | +        handle = monitor.register_event_handler(handler) | 
|  | 104 | + | 
|  | 105 | +        try: | 
|  | 106 | +            record(event) | 
|  | 107 | +        finally: | 
|  | 108 | +            monitor.unregister_event_handler(handle) | 
|  | 109 | + | 
|  | 110 | +        self.assertEqual(get_logging_handler.call_count, 1) | 
|  | 111 | +        self.assertEqual(len(events), 1) | 
|  | 112 | +        self.assertEqual(events[0].data["session"], "test_session") | 
|  | 113 | + | 
| 60 | 114 | 
 | 
| 61 | 115 | @patch("torchx.runner.events.record") | 
| 62 | 116 | class LogEventTest(unittest.TestCase): | 
|  | 
0 commit comments