Skip to content

Commit

Permalink
Updates following PR review.
Browse files Browse the repository at this point in the history
- Updated to avoid name shadowing of BaseTestHandler

- Updated test_micro_transport to use fixture for setup.  Ended up
  needing to refactor to use pytest instead of unittest, split up test
  functionality during refactor.
  • Loading branch information
Lunderberg committed Oct 22, 2021
1 parent ef64bca commit 9211c71
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 125 deletions.
8 changes: 6 additions & 2 deletions tests/python/unittest/test_micro_project_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
import tvm


# Implementing as a fixture so that the tvm.micro import doesn't occur
# until fixture setup time. This is necessary for pytest's collection
# phase to work when USE_MICRO=OFF, while still explicitly listing the
# tests as skipped.
@tvm.testing.fixture
def BaseTestHandler():
from tvm.micro import project_api

class BaseTestHandler(project_api.server.ProjectAPIHandler):
class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler):

DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
platform_name="platform_name",
Expand Down Expand Up @@ -67,7 +71,7 @@ def read_transport(self, n, timeout_sec):
def write_transport(self, data, timeout_sec):
assert False, "write_transport is not implemented for this test"

return BaseTestHandler
return BaseTestHandler_Impl


class Transport:
Expand Down
280 changes: 157 additions & 123 deletions tests/python/unittest/test_micro_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@
import tvm.testing


def test_transport_class():
# Implementing as a fixture so that the tvm.micro import doesn't occur
# until fixture setup time. This is necessary for pytest's collection
# phase to work when USE_MICRO=OFF, while still explicitly listing the
# tests as skipped.
@tvm.testing.fixture
def transport():
import tvm.micro

class TestTransport(tvm.micro.transport.Transport):
class MockTransport_Impl(tvm.micro.transport.Transport):
def __init__(self):
self.exc = None
self.to_return = None
Expand Down Expand Up @@ -61,130 +66,159 @@ def read(self, n, timeout_sec):
def write(self, data, timeout_sec):
return self._raise_or_return()

return TestTransport
return MockTransport_Impl()


@tvm.testing.fixture
def transport_logger(transport):
logger = logging.getLogger("transport_logger_test")
return tvm.micro.transport.TransportLogger("foo", transport, logger=logger)


@tvm.testing.fixture
def get_latest_log(caplog):
def inner():
return caplog.records[-1].getMessage()

with caplog.at_level(logging.INFO, "transport_logger_test"):
yield inner


@tvm.testing.requires_micro
def test_open(transport_logger, get_latest_log):
transport_logger.open()
assert get_latest_log() == "foo: opening transport"


@tvm.testing.requires_micro
def test_close(transport_logger, get_latest_log):
transport_logger.close()
assert get_latest_log() == "foo: closing transport"


@tvm.testing.requires_micro
def test_read_normal(transport, transport_logger, get_latest_log):
transport.to_return = b"data"
transport_logger.read(23, 3.0)
assert get_latest_log() == (
"foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61"
" data"
)


@tvm.testing.requires_micro
def test_read_multiline(transport, transport_logger, get_latest_log):
transport.to_return = b"data" * 6
transport_logger.read(23, 3.0)
assert get_latest_log() == (
"foo: read { 3.00s} 23 B -> [ 24 B]:\n"
"0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
"0010 64 61 74 61 64 61 74 61 datadata"
)


@tvm.testing.requires_micro
def test_read_no_timeout_prints(transport, transport_logger, get_latest_log):
transport.to_return = b"data"
transport_logger.read(15, None)
assert get_latest_log() == (
"foo: read { None } 15 B -> [ 4 B]: 64 61 74 61"
" data"
)


@tvm.testing.requires_micro
def test_read_io_timeout(transport, transport_logger, get_latest_log):
# IoTimeoutError includes the timeout value.
transport.exc = tvm.micro.transport.IoTimeoutError()
with pytest.raises(tvm.micro.transport.IoTimeoutError):
transport_logger.read(23, 0.0)

assert get_latest_log() == ("foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]")


@tvm.testing.requires_micro
def test_read_other_exception(transport, transport_logger, get_latest_log):
# Other exceptions are logged by name.
transport.exc = tvm.micro.transport.TransportClosedError()
with pytest.raises(tvm.micro.transport.TransportClosedError):
transport_logger.read(8, 0.0)

assert get_latest_log() == ("foo: read { 0.00s} 8 B -> [err: TransportClosedError]")


@tvm.testing.requires_micro
def test_read_keyboard_interrupt(transport, transport_logger, get_latest_log):
# KeyboardInterrupt produces no log record.
transport.exc = KeyboardInterrupt()
with pytest.raises(KeyboardInterrupt):
transport_logger.read(8, 0.0)

with pytest.raises(IndexError):
get_latest_log()


@tvm.testing.requires_micro
def test_write_normal(transport, transport_logger, get_latest_log):
transport.to_return = 3
transport_logger.write(b"data", 3.0)
assert get_latest_log() == (
"foo: write { 3.00s} <- [ 4 B]: 64 61 74 61"
" data"
)


@tvm.testing.requires_micro
def test_write_multiline(transport, transport_logger, get_latest_log):
# Normal log, multi-line data written.
transport.to_return = 20
transport_logger.write(b"data" * 6, 3.0)
assert get_latest_log() == (
"foo: write { 3.00s} <- [ 24 B]:\n"
"0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
"0010 64 61 74 61 64 61 74 61 datadata"
)


@tvm.testing.requires_micro
def test_write_no_timeout_prints(transport, transport_logger, get_latest_log):
transport.to_return = 3
transport_logger.write(b"data", None)
assert get_latest_log() == (
"foo: write { None } <- [ 4 B]: 64 61 74 61"
" data"
)


@tvm.testing.requires_micro
def test_write_io_timeout(transport, transport_logger, get_latest_log):
# IoTimeoutError includes the timeout value.
transport.exc = tvm.micro.transport.IoTimeoutError()
with pytest.raises(tvm.micro.transport.IoTimeoutError):
transport_logger.write(b"data", 0.0)

assert get_latest_log() == ("foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]")


@tvm.testing.requires_micro
def test_write_other_exception(transport, transport_logger, get_latest_log):
# Other exceptions are logged by name.
transport.exc = tvm.micro.transport.TransportClosedError()
with pytest.raises(tvm.micro.transport.TransportClosedError):
transport_logger.write(b"data", 0.0)

assert get_latest_log() == ("foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]")


@tvm.testing.requires_micro
class TransportLoggerTests(unittest.TestCase):
def test_transport_logger(self):
"""Tests the TransportLogger class."""

logger = logging.getLogger("transport_logger_test")
with self.assertLogs(logger) as test_log:
transport = test_transport_class()()
transport_logger = tvm.micro.transport.TransportLogger("foo", transport, logger=logger)

transport_logger.open()
assert test_log.records[-1].getMessage() == "foo: opening transport"

########### read() tests ##########

# Normal log, single-line data returned.
transport.to_return = b"data"
transport_logger.read(23, 3.0)
assert test_log.records[-1].getMessage() == (
"foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61"
" data"
)

# Normal log, multi-line data returned.
transport.to_return = b"data" * 6
transport_logger.read(23, 3.0)
assert test_log.records[-1].getMessage() == (
"foo: read { 3.00s} 23 B -> [ 24 B]:\n"
"0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
"0010 64 61 74 61 64 61 74 61 datadata"
)

# Lack of timeout prints.
transport.to_return = b"data"
transport_logger.read(15, None)
assert test_log.records[-1].getMessage() == (
"foo: read { None } 15 B -> [ 4 B]: 64 61 74 61"
" data"
)

# IoTimeoutError includes the timeout value.
transport.exc = tvm.micro.transport.IoTimeoutError()
with self.assertRaises(tvm.micro.transport.IoTimeoutError):
transport_logger.read(23, 0.0)

assert test_log.records[-1].getMessage() == (
"foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]"
)

# Other exceptions are logged by name.
transport.exc = tvm.micro.transport.TransportClosedError()
with self.assertRaises(tvm.micro.transport.TransportClosedError):
transport_logger.read(8, 0.0)

assert test_log.records[-1].getMessage() == (
"foo: read { 0.00s} 8 B -> [err: TransportClosedError]"
)

# KeyboardInterrupt produces no log record.
before_len = len(test_log.records)
transport.exc = KeyboardInterrupt()
with self.assertRaises(KeyboardInterrupt):
transport_logger.read(8, 0.0)

assert len(test_log.records) == before_len

########### write() tests ##########

# Normal log, single-line data written.
transport.to_return = 3
transport_logger.write(b"data", 3.0)
assert test_log.records[-1].getMessage() == (
"foo: write { 3.00s} <- [ 4 B]: 64 61 74 61"
" data"
)

# Normal log, multi-line data written.
transport.to_return = 20
transport_logger.write(b"data" * 6, 3.0)
assert test_log.records[-1].getMessage() == (
"foo: write { 3.00s} <- [ 24 B]:\n"
"0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n"
"0010 64 61 74 61 64 61 74 61 datadata"
)

# Lack of timeout prints.
transport.to_return = 3
transport_logger.write(b"data", None)
assert test_log.records[-1].getMessage() == (
"foo: write { None } <- [ 4 B]: 64 61 74 61"
" data"
)

# IoTimeoutError includes the timeout value.
transport.exc = tvm.micro.transport.IoTimeoutError()
with self.assertRaises(tvm.micro.transport.IoTimeoutError):
transport_logger.write(b"data", 0.0)

assert test_log.records[-1].getMessage() == (
"foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]"
)

# Other exceptions are logged by name.
transport.exc = tvm.micro.transport.TransportClosedError()
with self.assertRaises(tvm.micro.transport.TransportClosedError):
transport_logger.write(b"data", 0.0)

assert test_log.records[-1].getMessage() == (
"foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]"
)

# KeyboardInterrupt produces no log record.
before_len = len(test_log.records)
transport.exc = KeyboardInterrupt()
with self.assertRaises(KeyboardInterrupt):
transport_logger.write(b"data", 0.0)

assert len(test_log.records) == before_len

transport_logger.close()
assert test_log.records[-1].getMessage() == "foo: closing transport"
def test_write_keyboard_interrupt(transport, transport_logger, get_latest_log):
# KeyboardInterrupt produces no log record.
transport.exc = KeyboardInterrupt()
with pytest.raises(KeyboardInterrupt):
transport_logger.write(b"data", 0.0)

with pytest.raises(IndexError):
get_latest_log()


if __name__ == "__main__":
Expand Down

0 comments on commit 9211c71

Please sign in to comment.