Skip to content

Commit

Permalink
[MicroTVM][PyTest] Explicitly skip MicroTVM unittests. (apache#9335)
Browse files Browse the repository at this point in the history
* [MicroTVM][PyTest] Explicitly skip MicroTVM unittests.

Refactor unit tests so they will show as skipped if `USE_MICRO=OFF`.

* Updates following PR review.

- Updated to avoid name shadowing of BaseTestHandler

- Updated test_micro_transport to use fixture for setup.  Ended up
  needing to refactor to use pytest instead of unittest, split up test
  functionality during refactor.
  • Loading branch information
Lunderberg authored and mehrdadh committed Dec 1, 2021
1 parent 614446d commit 9b2ba91
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 173 deletions.
3 changes: 0 additions & 3 deletions tests/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,3 @@
# collect_ignore.append("unittest/test_auto_scheduler_measure.py") # exception ignored

collect_ignore.append("unittest/test_tir_intrin.py")

if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
collect_ignore.append("unittest/test_micro_transport.py")
4 changes: 1 addition & 3 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@
from tvm.topi.utils import get_const_tuple
from tvm.topi.testing import conv2d_nchw_python

pytest.importorskip("tvm.micro.testing")
from tvm.micro.testing import check_tune_log

BUILD = True
DEBUG = False

Expand Down Expand Up @@ -222,6 +219,7 @@ def test_platform_timer():
def test_autotune():
"""Verify that autotune works with micro."""
import tvm.relay as relay
from tvm.micro.testing import check_tune_log

data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32"))
weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32"))
Expand Down
137 changes: 92 additions & 45 deletions tests/python/unittest/test_micro_project_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,52 @@

import tvm

pytest.importorskip("tvm.micro")
from tvm.micro import project_api

# Implementing as a fixture so that the tvm.micro import doesn't occur
# until fixture setup time. This is necessary for pytest's collection
# phase to work when USE_MICRO=OFF, while still explicitly listing the
# tests as skipped.
@tvm.testing.fixture
def BaseTestHandler():
from tvm.micro import project_api

class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler):

DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
platform_name="platform_name",
is_template=True,
model_library_format_path="./model-library-format-path.sh",
project_options=[
project_api.server.ProjectOption(name="foo", help="Option foo"),
project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"),
],
)

class BaseTestHandler(project_api.server.ProjectAPIHandler):

DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
platform_name="platform_name",
is_template=True,
model_library_format_path="./model-library-format-path.sh",
project_options=[
project_api.server.ProjectOption(name="foo", help="Option foo"),
project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"),
],
)
def server_info_query(self, tvm_version):
return self.DEFAULT_TEST_SERVER_INFO

def server_info_query(self, tvm_version):
return self.DEFAULT_TEST_SERVER_INFO
def generate_project(self, model_library_format_path, crt_path, project_path, options):
assert False, "generate_project is not implemented for this test"

def generate_project(self, model_library_format_path, crt_path, project_path, options):
assert False, "generate_project is not implemented for this test"
def build(self, options):
assert False, "build is not implemented for this test"

def build(self, options):
assert False, "build is not implemented for this test"
def flash(self, options):
assert False, "flash is not implemented for this test"

def flash(self, options):
assert False, "flash is not implemented for this test"
def open_transport(self, options):
assert False, "open_transport is not implemented for this test"

def open_transport(self, options):
assert False, "open_transport is not implemented for this test"
def close_transport(self, options):
assert False, "open_transport is not implemented for this test"

def close_transport(self, options):
assert False, "open_transport is not implemented for this test"
def read_transport(self, n, timeout_sec):
assert False, "read_transport is not implemented for this test"

def read_transport(self, n, timeout_sec):
assert False, "read_transport is not implemented for this test"
def write_transport(self, data, timeout_sec):
assert False, "write_transport is not implemented for this test"

def write_transport(self, data, timeout_sec):
assert False, "write_transport is not implemented for this test"
return BaseTestHandler_Impl


class Transport:
Expand Down Expand Up @@ -100,6 +107,8 @@ def write(self, data):

class ClientServerFixture:
def __init__(self, handler):
from tvm.micro import project_api

self.handler = handler
self.client_to_server = Transport()
self.server_to_client = Transport()
Expand All @@ -121,7 +130,8 @@ def _process_server_request(self):
), "Server failed to process request"


def test_server_info_query():
@tvm.testing.requires_micro
def test_server_info_query(BaseTestHandler):
fixture = ClientServerFixture(BaseTestHandler())

# Examine reply explicitly because these are the defaults for all derivative test cases.
Expand All @@ -136,7 +146,10 @@ def test_server_info_query():
]


def test_server_info_query_wrong_tvm_version():
@tvm.testing.requires_micro
def test_server_info_query_wrong_tvm_version(BaseTestHandler):
from tvm.micro import project_api

def server_info_query(tvm_version):
raise project_api.server.UnsupportedTVMVersionError()

Expand All @@ -148,7 +161,10 @@ def server_info_query(tvm_version):
assert "UnsupportedTVMVersionError" in str(exc_info.value)


def test_server_info_query_wrong_protocol_version():
@tvm.testing.requires_micro
def test_server_info_query_wrong_protocol_version(BaseTestHandler):
from tvm.micro import project_api

ServerInfoProtocol = collections.namedtuple(
"ServerInfoProtocol", list(project_api.server.ServerInfo._fields) + ["protocol_version"]
)
Expand All @@ -166,7 +182,8 @@ def server_info_query(tvm_version):
assert "microTVM API Server supports protocol version 0; want 1" in str(exc_info.value)


def test_base_test_handler():
@tvm.testing.requires_micro
def test_base_test_handler(BaseTestHandler):
"""All methods should raise AssertionError on BaseTestHandler."""
fixture = ClientServerFixture(BaseTestHandler())

Expand All @@ -180,22 +197,27 @@ def test_base_test_handler():
assert (exc_info.exception) == f"{method} is not implemented for this test"


def test_build():
@tvm.testing.requires_micro
def test_build(BaseTestHandler):
with mock.patch.object(BaseTestHandler, "build", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
fixture.client.build(options={"bar": "baz"})

fixture.handler.build.assert_called_once_with(options={"bar": "baz"})


def test_flash():
@tvm.testing.requires_micro
def test_flash(BaseTestHandler):
with mock.patch.object(BaseTestHandler, "flash", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
fixture.client.flash(options={"bar": "baz"})
fixture.handler.flash.assert_called_once_with(options={"bar": "baz"})


def test_open_transport():
@tvm.testing.requires_micro
def test_open_transport(BaseTestHandler):
from tvm.micro import project_api

timeouts = project_api.server.TransportTimeouts(
session_start_retry_timeout_sec=1.0,
session_start_timeout_sec=2.0,
Expand All @@ -210,14 +232,18 @@ def test_open_transport():
fixture.handler.open_transport.assert_called_once_with({"bar": "baz"})


def test_close_transport():
@tvm.testing.requires_micro
def test_close_transport(BaseTestHandler):
with mock.patch.object(BaseTestHandler, "close_transport", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
fixture.client.close_transport()
fixture.handler.close_transport.assert_called_once_with()


def test_read_transport():
@tvm.testing.requires_micro
def test_read_transport(BaseTestHandler):
from tvm.micro import project_api

with mock.patch.object(BaseTestHandler, "read_transport", return_value=b"foo\x1b") as patch:
fixture = ClientServerFixture(BaseTestHandler())
assert fixture.client.read_transport(128, timeout_sec=5.0) == {"data": b"foo\x1b"}
Expand All @@ -239,7 +265,10 @@ def test_read_transport():
assert fixture.handler.read_transport.call_count == 3


def test_write_transport():
@tvm.testing.requires_micro
def test_write_transport(BaseTestHandler):
from tvm.micro import project_api

with mock.patch.object(BaseTestHandler, "write_transport", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
assert fixture.client.write_transport(b"foo", timeout_sec=5.0) is None
Expand All @@ -264,7 +293,10 @@ class ProjectAPITestError(Exception):
"""An error raised in test."""


def test_method_raises_error():
@tvm.testing.requires_micro
def test_method_raises_error(BaseTestHandler):
from tvm.micro import project_api

with mock.patch.object(
BaseTestHandler, "close_transport", side_effect=ProjectAPITestError
) as patch:
Expand All @@ -276,7 +308,10 @@ def test_method_raises_error():
assert "ProjectAPITestError" in str(exc_info.value)


def test_method_not_found():
@tvm.testing.requires_micro
def test_method_not_found(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

with pytest.raises(project_api.server.JSONRPCError) as exc_info:
Expand All @@ -285,7 +320,10 @@ def test_method_not_found():
assert exc_info.value.code == project_api.server.ErrorCode.METHOD_NOT_FOUND


def test_extra_param():
@tvm.testing.requires_micro
def test_extra_param(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# test one with has_preprocssing and one without
Expand All @@ -304,7 +342,10 @@ def test_extra_param():
assert "open_transport: extra parameters: invalid_param_name" in str(exc_info.value)


def test_missing_param():
@tvm.testing.requires_micro
def test_missing_param(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# test one with has_preprocssing and one without
Expand All @@ -323,7 +364,10 @@ def test_missing_param():
assert "open_transport: parameter options not given" in str(exc_info.value)


def test_incorrect_param_type():
@tvm.testing.requires_micro
def test_incorrect_param_type(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# The error message given at the JSON-RPC server level doesn't make sense when preprocessing is
Expand All @@ -338,7 +382,10 @@ def test_incorrect_param_type():
)


def test_invalid_request():
@tvm.testing.requires_micro
def test_invalid_request(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# Invalid JSON does not get a reply.
Expand Down
Loading

0 comments on commit 9b2ba91

Please sign in to comment.