| 
8 | 8 | import json  | 
9 | 9 | import logging  | 
10 | 10 | import unittest  | 
11 |  | -from typing import List  | 
12 | 11 | from unittest.mock import MagicMock, patch  | 
13 | 12 | 
 
  | 
14 | 13 | from torchx.runner.events import (  | 
15 | 14 |     _get_or_create_logger,  | 
16 | 15 |     log_event,  | 
17 |  | -    record,  | 
18 | 16 |     SourceType,  | 
19 | 17 |     TorchxEvent,  | 
20 | 18 | )  | 
21 | 19 | 
 
  | 
22 |  | -try:  | 
23 |  | -    from torch import monitor  | 
24 |  | - | 
25 |  | -    SKIP_MONITOR: bool = False  | 
26 |  | -except ImportError:  | 
27 |  | -    SKIP_MONITOR: bool = True  | 
28 |  | - | 
29 | 20 | 
 
  | 
30 | 21 | class TorchxEventLibTest(unittest.TestCase):  | 
31 | 22 |     def assert_event(  | 
@@ -66,55 +57,6 @@ def test_event_deser(self) -> None:  | 
66 | 57 |         deser_event = TorchxEvent.deserialize(json_event)  | 
67 | 58 |         self.assert_event(event, deser_event)  | 
68 | 59 | 
 
  | 
69 |  | -    @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available")  | 
70 |  | -    def test_monitor(self) -> None:  | 
71 |  | -        event = TorchxEvent(  | 
72 |  | -            session="test_session",  | 
73 |  | -            scheduler="test_scheduler",  | 
74 |  | -            api="test_api",  | 
75 |  | -            source=SourceType.EXTERNAL,  | 
76 |  | -        )  | 
77 |  | -        monitor_event = event.to_monitor_event()  | 
78 |  | -        self.assertEqual(  | 
79 |  | -            monitor_event.data,  | 
80 |  | -            {  | 
81 |  | -                "session": "test_session",  | 
82 |  | -                "scheduler": "test_scheduler",  | 
83 |  | -                "api": "test_api",  | 
84 |  | -                "source": "EXTERNAL",  | 
85 |  | -            },  | 
86 |  | -        )  | 
87 |  | -        self.assertEqual(monitor_event.name, "torch.runner.Event")  | 
88 |  | - | 
89 |  | -    @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available")  | 
90 |  | -    @patch("torchx.runner.events._get_or_create_logger")  | 
91 |  | -    def test_monitor_record(self, get_logging_handler: MagicMock) -> None:  | 
92 |  | -        event = TorchxEvent(  | 
93 |  | -            session="test_session",  | 
94 |  | -            scheduler="test_scheduler",  | 
95 |  | -            api="test_api",  | 
96 |  | -            source=SourceType.EXTERNAL,  | 
97 |  | -        )  | 
98 |  | -        # pyre-fixme[11]: Annotation `Event` is not defined as a type.  | 
99 |  | -        events: List[monitor.Event] = []  | 
100 |  | - | 
101 |  | -        def handler(e: monitor.Event) -> None:  | 
102 |  | -            events.append(e)  | 
103 |  | - | 
104 |  | -        # pyre-fixme[16]: Module `monitor` has no attribute `register_event_handler`.  | 
105 |  | -        handle = monitor.register_event_handler(handler)  | 
106 |  | - | 
107 |  | -        try:  | 
108 |  | -            record(event)  | 
109 |  | -        finally:  | 
110 |  | -            # pyre-fixme[16]: Module `monitor` has no attribute  | 
111 |  | -            #  `unregister_event_handler`.  | 
112 |  | -            monitor.unregister_event_handler(handle)  | 
113 |  | - | 
114 |  | -        self.assertEqual(get_logging_handler.call_count, 1)  | 
115 |  | -        self.assertEqual(len(events), 1)  | 
116 |  | -        self.assertEqual(events[0].data["session"], "test_session")  | 
117 |  | - | 
118 | 60 | 
 
  | 
119 | 61 | @patch("torchx.runner.events.record")  | 
120 | 62 | class LogEventTest(unittest.TestCase):  | 
 | 
0 commit comments