Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MicroTVM][PyTest] Explicitly skip MicroTVM unittests. #9335

Merged
merged 2 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions tests/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,3 @@
# collect_ignore.append("unittest/test_auto_scheduler_measure.py") # exception ignored

collect_ignore.append("unittest/test_tir_intrin.py")

if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON":
collect_ignore.append("unittest/test_micro_transport.py")
4 changes: 1 addition & 3 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@
from tvm.topi.utils import get_const_tuple
from tvm.topi.testing import conv2d_nchw_python

pytest.importorskip("tvm.micro.testing")
from tvm.micro.testing import check_tune_log

BUILD = True
DEBUG = False

Expand Down Expand Up @@ -222,6 +219,7 @@ def test_platform_timer():
def test_autotune():
"""Verify that autotune works with micro."""
import tvm.relay as relay
from tvm.micro.testing import check_tune_log

data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32"))
weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32"))
Expand Down
137 changes: 92 additions & 45 deletions tests/python/unittest/test_micro_project_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,52 @@

import tvm

pytest.importorskip("tvm.micro")
from tvm.micro import project_api

# Implementing as a fixture so that the tvm.micro import doesn't occur
# until fixture setup time. This is necessary for pytest's collection
# phase to work when USE_MICRO=OFF, while still explicitly listing the
# tests as skipped.
@tvm.testing.fixture
def BaseTestHandler():
from tvm.micro import project_api

class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler):

DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
platform_name="platform_name",
is_template=True,
model_library_format_path="./model-library-format-path.sh",
project_options=[
project_api.server.ProjectOption(name="foo", help="Option foo"),
project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"),
],
)

class BaseTestHandler(project_api.server.ProjectAPIHandler):

DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
platform_name="platform_name",
is_template=True,
model_library_format_path="./model-library-format-path.sh",
project_options=[
project_api.server.ProjectOption(name="foo", help="Option foo"),
project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"),
],
)
def server_info_query(self, tvm_version):
return self.DEFAULT_TEST_SERVER_INFO

def server_info_query(self, tvm_version):
return self.DEFAULT_TEST_SERVER_INFO
def generate_project(self, model_library_format_path, crt_path, project_path, options):
assert False, "generate_project is not implemented for this test"

def generate_project(self, model_library_format_path, crt_path, project_path, options):
assert False, "generate_project is not implemented for this test"
def build(self, options):
assert False, "build is not implemented for this test"

def build(self, options):
assert False, "build is not implemented for this test"
def flash(self, options):
assert False, "flash is not implemented for this test"

def flash(self, options):
assert False, "flash is not implemented for this test"
def open_transport(self, options):
assert False, "open_transport is not implemented for this test"

def open_transport(self, options):
assert False, "open_transport is not implemented for this test"
def close_transport(self, options):
assert False, "open_transport is not implemented for this test"

def close_transport(self, options):
assert False, "open_transport is not implemented for this test"
def read_transport(self, n, timeout_sec):
assert False, "read_transport is not implemented for this test"

def read_transport(self, n, timeout_sec):
assert False, "read_transport is not implemented for this test"
def write_transport(self, data, timeout_sec):
assert False, "write_transport is not implemented for this test"

def write_transport(self, data, timeout_sec):
assert False, "write_transport is not implemented for this test"
return BaseTestHandler_Impl


class Transport:
Expand Down Expand Up @@ -100,6 +107,8 @@ def write(self, data):

class ClientServerFixture:
def __init__(self, handler):
from tvm.micro import project_api

self.handler = handler
self.client_to_server = Transport()
self.server_to_client = Transport()
Expand All @@ -121,7 +130,8 @@ def _process_server_request(self):
), "Server failed to process request"


def test_server_info_query():
@tvm.testing.requires_micro
def test_server_info_query(BaseTestHandler):
fixture = ClientServerFixture(BaseTestHandler())

# Examine reply explicitly because these are the defaults for all derivative test cases.
Expand All @@ -136,7 +146,10 @@ def test_server_info_query():
]


def test_server_info_query_wrong_tvm_version():
@tvm.testing.requires_micro
def test_server_info_query_wrong_tvm_version(BaseTestHandler):
from tvm.micro import project_api

def server_info_query(tvm_version):
raise project_api.server.UnsupportedTVMVersionError()

Expand All @@ -148,7 +161,10 @@ def server_info_query(tvm_version):
assert "UnsupportedTVMVersionError" in str(exc_info.value)


def test_server_info_query_wrong_protocol_version():
@tvm.testing.requires_micro
def test_server_info_query_wrong_protocol_version(BaseTestHandler):
from tvm.micro import project_api

ServerInfoProtocol = collections.namedtuple(
"ServerInfoProtocol", list(project_api.server.ServerInfo._fields) + ["protocol_version"]
)
Expand All @@ -166,7 +182,8 @@ def server_info_query(tvm_version):
assert "microTVM API Server supports protocol version 0; want 1" in str(exc_info.value)


def test_base_test_handler():
@tvm.testing.requires_micro
def test_base_test_handler(BaseTestHandler):
"""All methods should raise AssertionError on BaseTestHandler."""
fixture = ClientServerFixture(BaseTestHandler())

Expand All @@ -180,22 +197,27 @@ def test_base_test_handler():
assert (exc_info.exception) == f"{method} is not implemented for this test"


def test_build():
@tvm.testing.requires_micro
def test_build(BaseTestHandler):
with mock.patch.object(BaseTestHandler, "build", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
fixture.client.build(options={"bar": "baz"})

fixture.handler.build.assert_called_once_with(options={"bar": "baz"})


def test_flash():
@tvm.testing.requires_micro
def test_flash(BaseTestHandler):
with mock.patch.object(BaseTestHandler, "flash", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
fixture.client.flash(options={"bar": "baz"})
fixture.handler.flash.assert_called_once_with(options={"bar": "baz"})


def test_open_transport():
@tvm.testing.requires_micro
def test_open_transport(BaseTestHandler):
from tvm.micro import project_api

timeouts = project_api.server.TransportTimeouts(
session_start_retry_timeout_sec=1.0,
session_start_timeout_sec=2.0,
Expand All @@ -210,14 +232,18 @@ def test_open_transport():
fixture.handler.open_transport.assert_called_once_with({"bar": "baz"})


def test_close_transport():
@tvm.testing.requires_micro
def test_close_transport(BaseTestHandler):
with mock.patch.object(BaseTestHandler, "close_transport", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
fixture.client.close_transport()
fixture.handler.close_transport.assert_called_once_with()


def test_read_transport():
@tvm.testing.requires_micro
def test_read_transport(BaseTestHandler):
from tvm.micro import project_api

with mock.patch.object(BaseTestHandler, "read_transport", return_value=b"foo\x1b") as patch:
fixture = ClientServerFixture(BaseTestHandler())
assert fixture.client.read_transport(128, timeout_sec=5.0) == {"data": b"foo\x1b"}
Expand All @@ -239,7 +265,10 @@ def test_read_transport():
assert fixture.handler.read_transport.call_count == 3


def test_write_transport():
@tvm.testing.requires_micro
def test_write_transport(BaseTestHandler):
from tvm.micro import project_api

with mock.patch.object(BaseTestHandler, "write_transport", return_value=None) as patch:
fixture = ClientServerFixture(BaseTestHandler())
assert fixture.client.write_transport(b"foo", timeout_sec=5.0) is None
Expand All @@ -264,7 +293,10 @@ class ProjectAPITestError(Exception):
"""An error raised in test."""


def test_method_raises_error():
@tvm.testing.requires_micro
def test_method_raises_error(BaseTestHandler):
from tvm.micro import project_api

with mock.patch.object(
BaseTestHandler, "close_transport", side_effect=ProjectAPITestError
) as patch:
Expand All @@ -276,7 +308,10 @@ def test_method_raises_error():
assert "ProjectAPITestError" in str(exc_info.value)


def test_method_not_found():
@tvm.testing.requires_micro
def test_method_not_found(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

with pytest.raises(project_api.server.JSONRPCError) as exc_info:
Expand All @@ -285,7 +320,10 @@ def test_method_not_found():
assert exc_info.value.code == project_api.server.ErrorCode.METHOD_NOT_FOUND


def test_extra_param():
@tvm.testing.requires_micro
def test_extra_param(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# test one with has_preprocssing and one without
Expand All @@ -304,7 +342,10 @@ def test_extra_param():
assert "open_transport: extra parameters: invalid_param_name" in str(exc_info.value)


def test_missing_param():
@tvm.testing.requires_micro
def test_missing_param(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# test one with has_preprocssing and one without
Expand All @@ -323,7 +364,10 @@ def test_missing_param():
assert "open_transport: parameter options not given" in str(exc_info.value)


def test_incorrect_param_type():
@tvm.testing.requires_micro
def test_incorrect_param_type(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# The error message given at the JSON-RPC server level doesn't make sense when preprocessing is
Expand All @@ -338,7 +382,10 @@ def test_incorrect_param_type():
)


def test_invalid_request():
@tvm.testing.requires_micro
def test_invalid_request(BaseTestHandler):
from tvm.micro import project_api

fixture = ClientServerFixture(BaseTestHandler())

# Invalid JSON does not get a reply.
Expand Down
Loading