|
8 | 8 | import json |
9 | 9 | import logging |
10 | 10 | import unittest |
11 | | -from typing import List |
12 | 11 | from unittest.mock import MagicMock, patch |
13 | 12 |
|
14 | 13 | from torchx.runner.events import ( |
15 | 14 | _get_or_create_logger, |
16 | 15 | log_event, |
17 | | - record, |
18 | 16 | SourceType, |
19 | 17 | TorchxEvent, |
20 | 18 | ) |
21 | 19 |
|
22 | | -try: |
23 | | - from torch import monitor |
24 | | - |
25 | | - SKIP_MONITOR: bool = False |
26 | | -except ImportError: |
27 | | - SKIP_MONITOR: bool = True |
28 | | - |
29 | 20 |
|
30 | 21 | class TorchxEventLibTest(unittest.TestCase): |
31 | 22 | def assert_event( |
@@ -66,55 +57,6 @@ def test_event_deser(self) -> None: |
66 | 57 | deser_event = TorchxEvent.deserialize(json_event) |
67 | 58 | self.assert_event(event, deser_event) |
68 | 59 |
|
69 | | - @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") |
70 | | - def test_monitor(self) -> None: |
71 | | - event = TorchxEvent( |
72 | | - session="test_session", |
73 | | - scheduler="test_scheduler", |
74 | | - api="test_api", |
75 | | - source=SourceType.EXTERNAL, |
76 | | - ) |
77 | | - monitor_event = event.to_monitor_event() |
78 | | - self.assertEqual( |
79 | | - monitor_event.data, |
80 | | - { |
81 | | - "session": "test_session", |
82 | | - "scheduler": "test_scheduler", |
83 | | - "api": "test_api", |
84 | | - "source": "EXTERNAL", |
85 | | - }, |
86 | | - ) |
87 | | - self.assertEqual(monitor_event.name, "torch.runner.Event") |
88 | | - |
89 | | - @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") |
90 | | - @patch("torchx.runner.events._get_or_create_logger") |
91 | | - def test_monitor_record(self, get_logging_handler: MagicMock) -> None: |
92 | | - event = TorchxEvent( |
93 | | - session="test_session", |
94 | | - scheduler="test_scheduler", |
95 | | - api="test_api", |
96 | | - source=SourceType.EXTERNAL, |
97 | | - ) |
98 | | - # pyre-fixme[11]: Annotation `Event` is not defined as a type. |
99 | | - events: List[monitor.Event] = [] |
100 | | - |
101 | | - def handler(e: monitor.Event) -> None: |
102 | | - events.append(e) |
103 | | - |
104 | | - # pyre-fixme[16]: Module `monitor` has no attribute `register_event_handler`. |
105 | | - handle = monitor.register_event_handler(handler) |
106 | | - |
107 | | - try: |
108 | | - record(event) |
109 | | - finally: |
110 | | - # pyre-fixme[16]: Module `monitor` has no attribute |
111 | | - # `unregister_event_handler`. |
112 | | - monitor.unregister_event_handler(handle) |
113 | | - |
114 | | - self.assertEqual(get_logging_handler.call_count, 1) |
115 | | - self.assertEqual(len(events), 1) |
116 | | - self.assertEqual(events[0].data["session"], "test_session") |
117 | | - |
118 | 60 |
|
119 | 61 | @patch("torchx.runner.events.record") |
120 | 62 | class LogEventTest(unittest.TestCase): |
|
0 commit comments