|
8 | 8 | import json
|
9 | 9 | import logging
|
10 | 10 | import unittest
|
11 |
| -from typing import List |
12 |
| -from unittest.mock import MagicMock, patch |
| 11 | +from unittest.mock import patch, MagicMock |
13 | 12 |
|
14 | 13 | from torchx.runner.events import (
|
15 | 14 | _get_or_create_logger,
|
16 | 15 | log_event,
|
17 |
| - record, |
18 |
| - SourceType, |
19 |
| - TorchxEvent, |
20 | 16 | )
|
21 | 17 |
|
22 |
| -try: |
23 |
| - from torch import monitor |
24 |
| - |
25 |
| - SKIP_MONITOR: bool = False |
26 |
| -except ImportError: |
27 |
| - SKIP_MONITOR: bool = True |
28 |
| - |
29 | 18 |
|
30 | 19 | class TorchxEventLibTest(unittest.TestCase):
|
31 | 20 | def assert_event(
|
@@ -66,51 +55,6 @@ def test_event_deser(self) -> None:
|
66 | 55 | deser_event = TorchxEvent.deserialize(json_event)
|
67 | 56 | self.assert_event(event, deser_event)
|
68 | 57 |
|
69 |
| - @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") |
70 |
| - def test_monitor(self) -> None: |
71 |
| - event = TorchxEvent( |
72 |
| - session="test_session", |
73 |
| - scheduler="test_scheduler", |
74 |
| - api="test_api", |
75 |
| - source=SourceType.EXTERNAL, |
76 |
| - ) |
77 |
| - monitor_event = event.to_monitor_event() |
78 |
| - self.assertEqual( |
79 |
| - monitor_event.data, |
80 |
| - { |
81 |
| - "session": "test_session", |
82 |
| - "scheduler": "test_scheduler", |
83 |
| - "api": "test_api", |
84 |
| - "source": "EXTERNAL", |
85 |
| - }, |
86 |
| - ) |
87 |
| - self.assertEqual(monitor_event.name, "torch.runner.Event") |
88 |
| - |
89 |
| - @unittest.skipIf(SKIP_MONITOR, "no torch.monitor available") |
90 |
| - @patch("torchx.runner.events._get_or_create_logger") |
91 |
| - def test_monitor_record(self, get_logging_handler: MagicMock) -> None: |
92 |
| - event = TorchxEvent( |
93 |
| - session="test_session", |
94 |
| - scheduler="test_scheduler", |
95 |
| - api="test_api", |
96 |
| - source=SourceType.EXTERNAL, |
97 |
| - ) |
98 |
| - events: List[monitor.Event] = [] |
99 |
| - |
100 |
| - def handler(e: monitor.Event) -> None: |
101 |
| - events.append(e) |
102 |
| - |
103 |
| - handle = monitor.register_event_handler(handler) |
104 |
| - |
105 |
| - try: |
106 |
| - record(event) |
107 |
| - finally: |
108 |
| - monitor.unregister_event_handler(handle) |
109 |
| - |
110 |
| - self.assertEqual(get_logging_handler.call_count, 1) |
111 |
| - self.assertEqual(len(events), 1) |
112 |
| - self.assertEqual(events[0].data["session"], "test_session") |
113 |
| - |
114 | 58 |
|
115 | 59 | @patch("torchx.runner.events.record")
|
116 | 60 | class LogEventTest(unittest.TestCase):
|
|
0 commit comments