Skip to content

Commit

Permalink
[TVMC] Workspace Pools Parameters
Browse files Browse the repository at this point in the history
Address comments, fix linting. Testing improved.
Change-Id: Iea79329b6b9ec1cbc51e5c293449bf6dd43b00c5
  • Loading branch information
dchauhan-arm committed May 24, 2022
1 parent 4d95fbe commit 2f3ca98
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 40 deletions.
6 changes: 4 additions & 2 deletions python/tvm/driver/tvmc/workspace_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@


def generate_workspace_pools_args(parser):
"""Generates arguments for each Workspace Pools's options"""
parser.add_argument(
"--workspace-pools",
help="""The name of the memory pool
Expand Down Expand Up @@ -120,8 +121,8 @@ def _target_attributes_to_pools(attr, targets):
target_attributes[pool_name][matched_targets[0]] = target_value
else:
raise TVMCException(
"""The workspace pool target specification for target
access needs to be the same TVM target as when specifying targets to use."""
"The workspace pool target specification for target access "
"needs to be the same TVM target as when specifying targets to use."
)

return target_attributes
Expand All @@ -132,6 +133,7 @@ def _attribute_to_pools(attr):


def workspace_pools_recombobulate(parsed, targets):
"""Reconstructs the Workspace Pools args and returns a WorkspaceMemoryPool object"""
WORKSPACE_POOL_PARAMS = [
"workspace_pools_size_hint_bytes",
"workspace_pools_target_access",
Expand Down
57 changes: 19 additions & 38 deletions tests/python/driver/tvmc/test_command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,11 @@
import platform
import pytest
import shutil
import argparse

from unittest import mock
from tvm.driver import tvmc
from tvm.target import Target
from pytest_lazyfixture import lazy_fixture
from unittest import mock
import tvm
from tvm.driver.tvmc.main import _main
from tvm.driver.tvmc.workspace_pools import (
generate_workspace_pools_args,
workspace_pools_recombobulate,
)


@pytest.mark.skipif(
Expand Down Expand Up @@ -164,37 +158,24 @@ def test_tvmc_tune_file_check(capsys, invalid_input):
assert captured.err == expected_err, on_assert_error


@mock.patch("tvm.relay.build")
@mock.patch("tvm.driver.tvmc.load")
@mock.patch("tvm.transform.PassContext")
@mock.patch("tvm.relay.build", side_effect=tvm.relay.build)
@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None)
@mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target")
def test_compile_check_workspace_pools(mock_ct, mock_pkg, mock_pc, mock_fe, mock_relay):
mock_codegen = {}
mock_codegen["config_key"] = "relay.ext.mock.options"
mock_codegen["pass_pipeline"] = lambda *args, **kwargs: None

mock_fe.return_value = mock.MagicMock()
mock_ct.return_value = mock_codegen
mock_relay.return_value = mock.MagicMock()

parser = argparse.ArgumentParser()
generate_workspace_pools_args(parser)
parsed, _ = parser.parse_known_args(
[
"--workspace-pools=sram",
"--workspace-pools-target-access=sram:llvm:ro",
]
)

targets = [Target("llvm")]
memory_pools = workspace_pools_recombobulate(parsed, targets)
def test_tvmc_workspace_pools_check(mock_pkg, mock_relay, keras_simple, tmpdir_factory):
pytest.importorskip("tensorflow")
tmpdir = tmpdir_factory.mktemp("data")

tvmc_model = tvmc.load("no_file_needed")
tvmc.compile(
tvmc_model,
target="mockcodegen -testopt=value, llvm",
workspace_pools=memory_pools,
# Test model compilation
package_path = os.path.join(tmpdir, "keras-tvm.tar")
compile_str = (
f"tvmc compile --target=llvm --workspace-pools=sram "
f"--workspace-pools-target-access=sram:llvm:rw "
f"--output={package_path} {keras_simple}"
)

compile_args = compile_str.split(" ")[1:]
_main(compile_args)
assert os.path.exists(package_path)
assert mock_relay.call_count == 1
assert mock_relay.call_args_list[0][1]["workspace_memory_pools"].pools[0].pool_name == "sram"
assert (
len(mock_relay.call_args_list[0][1]["workspace_memory_pools"].pools[0].target_access) == 1
)
24 changes: 24 additions & 0 deletions tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import pytest

import tvm
from tvm.ir.memory_pools import PoolInfo, WorkspaceMemoryPools
from tvm.target import Target
import tvm.testing
from tvm.testing.utils import ethosn_available
from tvm.relay.backend import Runtime, Executor
Expand Down Expand Up @@ -680,5 +682,27 @@ def test_compile_tflite_module_with_mod_name_and_ethosu(
assert b"tvmgen_classify_ethos_u_main_" in content


@mock.patch("tvm.relay.build")
@mock.patch("tvm.driver.tvmc.load")
@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None)
def test_compile_check_workspace_pools(mock_pkg, mock_fe, mock_relay):
mock_fe.return_value = mock.MagicMock()
mock_relay.return_value = mock.MagicMock()

memory_pools = WorkspaceMemoryPools(
[PoolInfo(pool_name="sram", target_access={Target("llvm"): "rw"})]
)

tvmc_model = tvmc.load("no_file_needed")
tvmc.compile(
tvmc_model,
target="llvm",
workspace_pools=memory_pools,
)

assert mock_relay.call_count == 1
assert mock_relay.call_args_list[0][1]["workspace_memory_pools"] == memory_pools


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 2f3ca98

Please sign in to comment.