From 9b2ba910f9c36c2506c242a53fc769a33f5bf380 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Tue, 9 Nov 2021 20:16:02 -0500 Subject: [PATCH] [MicroTVM][PyTest] Explicitly skip MicroTVM unittests. (#9335) * [MicroTVM][PyTest] Explicitly skip MicroTVM unittests. Refactor unit tests so they will show as skipped if `USE_MICRO=OFF`. * Updates following PR review. - Updated to avoid name shadowing of BaseTestHandler - Updated test_micro_transport to use fixture for setup. Ended up needing to refactor to use pytest instead of unittest, split up test functionality during refactor. --- tests/python/conftest.py | 3 - tests/python/unittest/test_crt.py | 4 +- .../python/unittest/test_micro_project_api.py | 137 ++++++--- tests/python/unittest/test_micro_transport.py | 282 ++++++++++-------- 4 files changed, 253 insertions(+), 173 deletions(-) diff --git a/tests/python/conftest.py b/tests/python/conftest.py index e8042c8f50957..ab3ea4e4ec060 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -37,6 +37,3 @@ # collect_ignore.append("unittest/test_auto_scheduler_measure.py") # exception ignored collect_ignore.append("unittest/test_tir_intrin.py") - -if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON": - collect_ignore.append("unittest/test_micro_transport.py") diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 9450a937a155b..fbf908170938c 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -35,9 +35,6 @@ from tvm.topi.utils import get_const_tuple from tvm.topi.testing import conv2d_nchw_python -pytest.importorskip("tvm.micro.testing") -from tvm.micro.testing import check_tune_log - BUILD = True DEBUG = False @@ -222,6 +219,7 @@ def test_platform_timer(): def test_autotune(): """Verify that autotune works with micro.""" import tvm.relay as relay + from tvm.micro.testing import check_tune_log data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32")) weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32")) diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py index e319318656ef9..1e511c41d73eb 100644 --- a/tests/python/unittest/test_micro_project_api.py +++ b/tests/python/unittest/test_micro_project_api.py @@ -26,45 +26,52 @@ import tvm -pytest.importorskip("tvm.micro") -from tvm.micro import project_api +# Implementing as a fixture so that the tvm.micro import doesn't occur +# until fixture setup time. This is necessary for pytest's collection +# phase to work when USE_MICRO=OFF, while still explicitly listing the +# tests as skipped. +@tvm.testing.fixture +def BaseTestHandler(): + from tvm.micro import project_api + + class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler): + + DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo( + platform_name="platform_name", + is_template=True, + model_library_format_path="./model-library-format-path.sh", + project_options=[ + project_api.server.ProjectOption(name="foo", help="Option foo"), + project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"), + ], + ) -class BaseTestHandler(project_api.server.ProjectAPIHandler): - - DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo( - platform_name="platform_name", - is_template=True, - model_library_format_path="./model-library-format-path.sh", - project_options=[ - project_api.server.ProjectOption(name="foo", help="Option foo"), - project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"), - ], - ) + def server_info_query(self, tvm_version): + return self.DEFAULT_TEST_SERVER_INFO - def server_info_query(self, tvm_version): - return self.DEFAULT_TEST_SERVER_INFO + def generate_project(self, model_library_format_path, crt_path, project_path, options): + assert False, "generate_project is not implemented for this test" - def generate_project(self, model_library_format_path, crt_path, project_path, options): - assert False, "generate_project is not implemented for this test" + def build(self, options): + assert False, "build is not implemented for this test" - def build(self, options): - assert False, "build is not implemented for this test" + def flash(self, options): + assert False, "flash is not implemented for this test" - def flash(self, options): - assert False, "flash is not implemented for this test" + def open_transport(self, options): + assert False, "open_transport is not implemented for this test" - def open_transport(self, options): - assert False, "open_transport is not implemented for this test" + def close_transport(self, options): + assert False, "open_transport is not implemented for this test" - def close_transport(self, options): - assert False, "open_transport is not implemented for this test" + def read_transport(self, n, timeout_sec): + assert False, "read_transport is not implemented for this test" - def read_transport(self, n, timeout_sec): - assert False, "read_transport is not implemented for this test" + def write_transport(self, data, timeout_sec): + assert False, "write_transport is not implemented for this test" - def write_transport(self, data, timeout_sec): - assert False, "write_transport is not implemented for this test" + return BaseTestHandler_Impl class Transport: @@ -100,6 +107,8 @@ def write(self, data): class ClientServerFixture: def __init__(self, handler): + from tvm.micro import project_api + self.handler = handler self.client_to_server = Transport() self.server_to_client = Transport() @@ -121,7 +130,8 @@ def _process_server_request(self): ), "Server failed to process request" -def test_server_info_query(): +@tvm.testing.requires_micro +def test_server_info_query(BaseTestHandler): fixture = ClientServerFixture(BaseTestHandler()) # Examine reply explicitly because these are the defaults for all derivative test cases. @@ -136,7 +146,10 @@ def test_server_info_query(): ] -def test_server_info_query_wrong_tvm_version(): +@tvm.testing.requires_micro +def test_server_info_query_wrong_tvm_version(BaseTestHandler): + from tvm.micro import project_api + def server_info_query(tvm_version): raise project_api.server.UnsupportedTVMVersionError() @@ -148,7 +161,10 @@ def server_info_query(tvm_version): assert "UnsupportedTVMVersionError" in str(exc_info.value) -def test_server_info_query_wrong_protocol_version(): +@tvm.testing.requires_micro +def test_server_info_query_wrong_protocol_version(BaseTestHandler): + from tvm.micro import project_api + ServerInfoProtocol = collections.namedtuple( "ServerInfoProtocol", list(project_api.server.ServerInfo._fields) + ["protocol_version"] ) @@ -166,7 +182,8 @@ def server_info_query(tvm_version): assert "microTVM API Server supports protocol version 0; want 1" in str(exc_info.value) -def test_base_test_handler(): +@tvm.testing.requires_micro +def test_base_test_handler(BaseTestHandler): """All methods should raise AssertionError on BaseTestHandler.""" fixture = ClientServerFixture(BaseTestHandler()) @@ -180,7 +197,8 @@ def test_base_test_handler(): assert (exc_info.exception) == f"{method} is not implemented for this test" -def test_build(): +@tvm.testing.requires_micro +def test_build(BaseTestHandler): with mock.patch.object(BaseTestHandler, "build", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) fixture.client.build(options={"bar": "baz"}) @@ -188,14 +206,18 @@ def test_build(): fixture.handler.build.assert_called_once_with(options={"bar": "baz"}) -def test_flash(): +@tvm.testing.requires_micro +def test_flash(BaseTestHandler): with mock.patch.object(BaseTestHandler, "flash", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) fixture.client.flash(options={"bar": "baz"}) fixture.handler.flash.assert_called_once_with(options={"bar": "baz"}) -def test_open_transport(): +@tvm.testing.requires_micro +def test_open_transport(BaseTestHandler): + from tvm.micro import project_api + timeouts = project_api.server.TransportTimeouts( session_start_retry_timeout_sec=1.0, session_start_timeout_sec=2.0, @@ -210,14 +232,18 @@ def test_open_transport(): fixture.handler.open_transport.assert_called_once_with({"bar": "baz"}) -def test_close_transport(): +@tvm.testing.requires_micro +def test_close_transport(BaseTestHandler): with mock.patch.object(BaseTestHandler, "close_transport", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) fixture.client.close_transport() fixture.handler.close_transport.assert_called_once_with() -def test_read_transport(): +@tvm.testing.requires_micro +def test_read_transport(BaseTestHandler): + from tvm.micro import project_api + with mock.patch.object(BaseTestHandler, "read_transport", return_value=b"foo\x1b") as patch: fixture = ClientServerFixture(BaseTestHandler()) assert fixture.client.read_transport(128, timeout_sec=5.0) == {"data": b"foo\x1b"} @@ -239,7 +265,10 @@ def test_read_transport(): assert fixture.handler.read_transport.call_count == 3 -def test_write_transport(): +@tvm.testing.requires_micro +def test_write_transport(BaseTestHandler): + from tvm.micro import project_api + with mock.patch.object(BaseTestHandler, "write_transport", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) assert fixture.client.write_transport(b"foo", timeout_sec=5.0) is None @@ -264,7 +293,10 @@ class ProjectAPITestError(Exception): """An error raised in test.""" -def test_method_raises_error(): +@tvm.testing.requires_micro +def test_method_raises_error(BaseTestHandler): + from tvm.micro import project_api + with mock.patch.object( BaseTestHandler, "close_transport", side_effect=ProjectAPITestError ) as patch: @@ -276,7 +308,10 @@ def test_method_raises_error(): assert "ProjectAPITestError" in str(exc_info.value) -def test_method_not_found(): +@tvm.testing.requires_micro +def test_method_not_found(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) with pytest.raises(project_api.server.JSONRPCError) as exc_info: @@ -285,7 +320,10 @@ def test_method_not_found(): assert exc_info.value.code == project_api.server.ErrorCode.METHOD_NOT_FOUND -def test_extra_param(): +@tvm.testing.requires_micro +def test_extra_param(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # test one with has_preprocssing and one without @@ -304,7 +342,10 @@ def test_extra_param(): assert "open_transport: extra parameters: invalid_param_name" in str(exc_info.value) -def test_missing_param(): +@tvm.testing.requires_micro +def test_missing_param(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # test one with has_preprocssing and one without @@ -323,7 +364,10 @@ def test_missing_param(): assert "open_transport: parameter options not given" in str(exc_info.value) -def test_incorrect_param_type(): +@tvm.testing.requires_micro +def test_incorrect_param_type(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # The error message given at the JSON-RPC server level doesn't make sense when preprocessing is @@ -338,7 +382,10 @@ def test_incorrect_param_type(): ) -def test_invalid_request(): +@tvm.testing.requires_micro +def test_invalid_request(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # Invalid JSON does not get a reply. diff --git a/tests/python/unittest/test_micro_transport.py b/tests/python/unittest/test_micro_transport.py index a188e612763f0..2fbfada198e38 100644 --- a/tests/python/unittest/test_micro_transport.py +++ b/tests/python/unittest/test_micro_transport.py @@ -26,11 +26,15 @@ import tvm.testing -@tvm.testing.requires_micro -class TransportLoggerTests(unittest.TestCase): +# Implementing as a fixture so that the tvm.micro import doesn't occur +# until fixture setup time. This is necessary for pytest's collection +# phase to work when USE_MICRO=OFF, while still explicitly listing the +# tests as skipped. +@tvm.testing.fixture +def transport(): import tvm.micro - class TestTransport(tvm.micro.transport.Transport): + class MockTransport_Impl(tvm.micro.transport.Transport): def __init__(self): self.exc = None self.to_return = None @@ -62,125 +66,159 @@ def read(self, n, timeout_sec): def write(self, data, timeout_sec): return self._raise_or_return() - def test_transport_logger(self): - """Tests the TransportLogger class.""" - - logger = logging.getLogger("transport_logger_test") - with self.assertLogs(logger) as test_log: - transport = self.TestTransport() - transport_logger = tvm.micro.transport.TransportLogger("foo", transport, logger=logger) - - transport_logger.open() - assert test_log.records[-1].getMessage() == "foo: opening transport" - - ########### read() tests ########## - - # Normal log, single-line data returned. - transport.to_return = b"data" - transport_logger.read(23, 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61" - " data" - ) - - # Normal log, multi-line data returned. - transport.to_return = b"data" * 6 - transport_logger.read(23, 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: read { 3.00s} 23 B -> [ 24 B]:\n" - "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" - "0010 64 61 74 61 64 61 74 61 datadata" - ) - - # Lack of timeout prints. - transport.to_return = b"data" - transport_logger.read(15, None) - assert test_log.records[-1].getMessage() == ( - "foo: read { None } 15 B -> [ 4 B]: 64 61 74 61" - " data" - ) - - # IoTimeoutError includes the timeout value. - transport.exc = tvm.micro.transport.IoTimeoutError() - with self.assertRaises(tvm.micro.transport.IoTimeoutError): - transport_logger.read(23, 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]" - ) - - # Other exceptions are logged by name. - transport.exc = tvm.micro.transport.TransportClosedError() - with self.assertRaises(tvm.micro.transport.TransportClosedError): - transport_logger.read(8, 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: read { 0.00s} 8 B -> [err: TransportClosedError]" - ) - - # KeyboardInterrupt produces no log record. - before_len = len(test_log.records) - transport.exc = KeyboardInterrupt() - with self.assertRaises(KeyboardInterrupt): - transport_logger.read(8, 0.0) - - assert len(test_log.records) == before_len - - ########### write() tests ########## - - # Normal log, single-line data written. - transport.to_return = 3 - transport_logger.write(b"data", 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61" - " data" - ) - - # Normal log, multi-line data written. - transport.to_return = 20 - transport_logger.write(b"data" * 6, 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 24 B]:\n" - "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" - "0010 64 61 74 61 64 61 74 61 datadata" - ) - - # Lack of timeout prints. - transport.to_return = 3 - transport_logger.write(b"data", None) - assert test_log.records[-1].getMessage() == ( - "foo: write { None } <- [ 4 B]: 64 61 74 61" - " data" - ) - - # IoTimeoutError includes the timeout value. - transport.exc = tvm.micro.transport.IoTimeoutError() - with self.assertRaises(tvm.micro.transport.IoTimeoutError): - transport_logger.write(b"data", 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]" - ) - - # Other exceptions are logged by name. - transport.exc = tvm.micro.transport.TransportClosedError() - with self.assertRaises(tvm.micro.transport.TransportClosedError): - transport_logger.write(b"data", 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]" - ) - - # KeyboardInterrupt produces no log record. - before_len = len(test_log.records) - transport.exc = KeyboardInterrupt() - with self.assertRaises(KeyboardInterrupt): - transport_logger.write(b"data", 0.0) - - assert len(test_log.records) == before_len - - transport_logger.close() - assert test_log.records[-1].getMessage() == "foo: closing transport" + return MockTransport_Impl() + + +@tvm.testing.fixture +def transport_logger(transport): + logger = logging.getLogger("transport_logger_test") + return tvm.micro.transport.TransportLogger("foo", transport, logger=logger) + + +@tvm.testing.fixture +def get_latest_log(caplog): + def inner(): + return caplog.records[-1].getMessage() + + with caplog.at_level(logging.INFO, "transport_logger_test"): + yield inner + + +@tvm.testing.requires_micro +def test_open(transport_logger, get_latest_log): + transport_logger.open() + assert get_latest_log() == "foo: opening transport" + + +@tvm.testing.requires_micro +def test_close(transport_logger, get_latest_log): + transport_logger.close() + assert get_latest_log() == "foo: closing transport" + + +@tvm.testing.requires_micro +def test_read_normal(transport, transport_logger, get_latest_log): + transport.to_return = b"data" + transport_logger.read(23, 3.0) + assert get_latest_log() == ( + "foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_read_multiline(transport, transport_logger, get_latest_log): + transport.to_return = b"data" * 6 + transport_logger.read(23, 3.0) + assert get_latest_log() == ( + "foo: read { 3.00s} 23 B -> [ 24 B]:\n" + "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" + "0010 64 61 74 61 64 61 74 61 datadata" + ) + + +@tvm.testing.requires_micro +def test_read_no_timeout_prints(transport, transport_logger, get_latest_log): + transport.to_return = b"data" + transport_logger.read(15, None) + assert get_latest_log() == ( + "foo: read { None } 15 B -> [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_read_io_timeout(transport, transport_logger, get_latest_log): + # IoTimeoutError includes the timeout value. + transport.exc = tvm.micro.transport.IoTimeoutError() + with pytest.raises(tvm.micro.transport.IoTimeoutError): + transport_logger.read(23, 0.0) + + assert get_latest_log() == ("foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]") + + +@tvm.testing.requires_micro +def test_read_other_exception(transport, transport_logger, get_latest_log): + # Other exceptions are logged by name. + transport.exc = tvm.micro.transport.TransportClosedError() + with pytest.raises(tvm.micro.transport.TransportClosedError): + transport_logger.read(8, 0.0) + + assert get_latest_log() == ("foo: read { 0.00s} 8 B -> [err: TransportClosedError]") + + +@tvm.testing.requires_micro +def test_read_keyboard_interrupt(transport, transport_logger, get_latest_log): + # KeyboardInterrupt produces no log record. + transport.exc = KeyboardInterrupt() + with pytest.raises(KeyboardInterrupt): + transport_logger.read(8, 0.0) + + with pytest.raises(IndexError): + get_latest_log() + + +@tvm.testing.requires_micro +def test_write_normal(transport, transport_logger, get_latest_log): + transport.to_return = 3 + transport_logger.write(b"data", 3.0) + assert get_latest_log() == ( + "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_write_multiline(transport, transport_logger, get_latest_log): + # Normal log, multi-line data written. + transport.to_return = 20 + transport_logger.write(b"data" * 6, 3.0) + assert get_latest_log() == ( + "foo: write { 3.00s} <- [ 24 B]:\n" + "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" + "0010 64 61 74 61 64 61 74 61 datadata" + ) + + +@tvm.testing.requires_micro +def test_write_no_timeout_prints(transport, transport_logger, get_latest_log): + transport.to_return = 3 + transport_logger.write(b"data", None) + assert get_latest_log() == ( + "foo: write { None } <- [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_write_io_timeout(transport, transport_logger, get_latest_log): + # IoTimeoutError includes the timeout value. + transport.exc = tvm.micro.transport.IoTimeoutError() + with pytest.raises(tvm.micro.transport.IoTimeoutError): + transport_logger.write(b"data", 0.0) + + assert get_latest_log() == ("foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]") + + +@tvm.testing.requires_micro +def test_write_other_exception(transport, transport_logger, get_latest_log): + # Other exceptions are logged by name. + transport.exc = tvm.micro.transport.TransportClosedError() + with pytest.raises(tvm.micro.transport.TransportClosedError): + transport_logger.write(b"data", 0.0) + + assert get_latest_log() == ("foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]") + + +@tvm.testing.requires_micro +def test_write_keyboard_interrupt(transport, transport_logger, get_latest_log): + # KeyboardInterrupt produces no log record. + transport.exc = KeyboardInterrupt() + with pytest.raises(KeyboardInterrupt): + transport_logger.write(b"data", 0.0) + + with pytest.raises(IndexError): + get_latest_log() if __name__ == "__main__":