diff --git a/include/tvm/ir/memory_pools.h b/include/tvm/ir/memory_pools.h index ee07841de412..ebab13cf3adb 100644 --- a/include/tvm/ir/memory_pools.h +++ b/include/tvm/ir/memory_pools.h @@ -65,6 +65,7 @@ struct PoolInfoNode : public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pool_name", &pool_name); + v->Visit("targets", &targets); v->Visit("size_hint_bytes", &size_hint_bytes); v->Visit("clock_frequency_hz", &clock_frequency_hz); v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle); diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 138504470459..2955df55432d 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -26,6 +26,7 @@ from tvm import autotvm, auto_scheduler from tvm import relay from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity +from tvm.ir.memory_pools import WorkspaceMemoryPools from tvm.target import Target from tvm.relay.backend import Executor, Runtime @@ -37,6 +38,7 @@ from .pass_list import parse_pass_list_str from .transform import convert_graph_layout from .shape_parser import parse_shape_string +from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -142,10 +144,11 @@ def add_compile_parser(subparsers, _, json_params): default="default", help="The output module name. Defaults to 'default'.", ) - for one_entry in json_params: parser.set_defaults(**one_entry) + generate_workspace_pools_args(parser) + def drive_compile(args): """Invoke tvmc.compiler module with command line arguments @@ -161,6 +164,7 @@ def drive_compile(args): Zero if successfully completed """ + if not os.path.isfile(args.FILE): raise TVMCException( f"Input file '{args.FILE}' doesn't exist, is a broken symbolic link, or a directory." @@ -170,6 +174,9 @@ def drive_compile(args): dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None + additional_targets = reconstruct_target_args(args) + workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets) + compile_model( tvmc_model, args.target, @@ -186,8 +193,11 @@ def drive_compile(args): desired_layout=args.desired_layout, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, - additional_target_options=reconstruct_target_args(args), mod_name=args.module_name, + additional_target_options=additional_targets, + workspace_pools=( + workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets) + ), ) return 0 @@ -212,6 +222,7 @@ def compile_model( additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, use_vm: bool = False, mod_name: Optional[str] = "default", + workspace_pools: Optional[WorkspaceMemoryPools] = None, ): """Compile a model from a supported framework into a TVM module. @@ -263,6 +274,9 @@ def compile_model( Whether to use the VM to compile the model as opposed to the graph executor mod_name: str, optional The module name + workspace_pools: WorkspaceMemoryPools, optional + Specification of WorkspacePoolInfo objects to be used as workspace memory in the + compilation. Returns ------- @@ -313,6 +327,7 @@ def compile_model( params=params, use_vm=use_vm, mod_name=mod_name, + workspace_pools=workspace_pools, ) else: with autotvm.apply_history_best(tuning_records): @@ -328,6 +343,7 @@ def compile_model( params=params, use_vm=use_vm, mod_name=mod_name, + workspace_pools=workspace_pools, ) else: with tvm.transform.PassContext( @@ -342,6 +358,7 @@ def compile_model( params=params, use_vm=use_vm, mod_name=mod_name, + workspace_pools=workspace_pools, ) # Generate output dump files with sources @@ -380,6 +397,7 @@ def build( params: Dict[str, tvm.nd.NDArray], use_vm: bool, mod_name: str, + workspace_pools: Optional[WorkspaceMemoryPools], ): """ Builds the model with the provided executor. @@ -408,7 +426,13 @@ def build( return relay.vm.compile(mod, target=tvm_target, params=params) logger.debug("building with relay build") return relay.build( - mod, target=tvm_target, executor=executor, runtime=runtime, params=params, mod_name=mod_name + mod, + target=tvm_target, + executor=executor, + runtime=runtime, + params=params, + mod_name=mod_name, + workspace_memory_pools=workspace_pools, ) diff --git a/python/tvm/driver/tvmc/workspace_pools.py b/python/tvm/driver/tvmc/workspace_pools.py new file mode 100644 index 000000000000..2c91488fb48b --- /dev/null +++ b/python/tvm/driver/tvmc/workspace_pools.py @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Functions for processing dynamic workspace pool TVMC args +""" + + +import logging +import re + +from tvm.driver.tvmc import TVMCException +from tvm.target import Target +from tvm.ir.memory_pools import PoolInfoProperties, WorkspaceMemoryPools, WorkspacePoolInfo + + +# pylint: disable=invalid-name +logger = logging.getLogger("TVMC") + + +def generate_workspace_pools_args(parser): + """Generates arguments for each Workspace Pools's options""" + parser.add_argument( + "--workspace-pools", + help="""The name of the memory pool + Example usage: --workspace-pools=flash""", + ) + parser.add_argument( + "--workspace-pools-targets", + help="""The name of the targets specified for the memory pool + Example usage: --workspace-pools-targets=flash:llvm""", + action="append", + ) + parser.add_argument( + "--workspace-pools-size-hint-bytes", + nargs="?", + help="""The expected size hint to be used by the allocator. + Example usage: --workspace-pools-size-hint-bytes=flash:8""", + action="append", + ) + parser.add_argument( + "--workspace-pools-clock-frequency-hz", + nargs="?", + help="""The clock frequency that the memory pool runs at in Hz. + Example usage: --workspace-pools-clock-frequency-hz=flash:70000000""", + action="append", + ) + parser.add_argument( + "--workspace-pools-read-bandwidth-bytes-per-cycle", + nargs="?", + help="""The read bandwidth of the memory pool in bytes/cycle. + Example usage: --workspace-pools-read-bandwidth-bytes-per-cycle=flash:4""", + action="append", + ) + parser.add_argument( + "--workspace-pools-write-bandwidth-bytes-per-cycle", + nargs="?", + help="""The write bandwidth of the memory pool in bytes/cycle. + Example usage: --workspace-pools-write-bandwidth-bytes-per-cycle=flash:8""", + action="append", + ) + parser.add_argument( + "--workspace-pools-read-latency-cycles", + nargs="?", + help="""The read latency of the memory pool in cycles. + Example usage: --workspace-pools-read-latency-cycles=flash:4""", + action="append", + ) + parser.add_argument( + "--workspace-pools-write-latency-cycles", + nargs="?", + help="""The write latency of the memory pool in cycles. + Example usage: --workspace-pools-write-latency-cycles=flash:8""", + action="append", + ) + parser.add_argument( + "--workspace-pools-target-burst-bytes", + help="""The burst length of the memory pool in bytes per target. + Example usage: --workspace-pools-target-burst-bytes=flash:accel:1""", + action="append", + ) + + +def _parse_target_burst(attr_str, pool_name): + if pool_name not in attr_str: + return {} + + return {target: int(attr_str[pool_name][target]) for target in attr_str[pool_name]} + + +def _parse_target_string(attr_str, targets, pool_name): + if attr_str is None: + raise TVMCException(f'No target specified for Workspace Pool "{pool_name}"') + + target_name = [re.split(",", attr_str)] + matched_targets = [ + target + for target in targets + if any(target.kind.name in target_string_match for target_string_match in target_name[0]) + ] + if not matched_targets: + raise TVMCException(f'Workspace Pool "{pool_name}" using undefined Target "{target_name}"') + return matched_targets + + +def _split_pools_to_pool_names(attr_str): + return re.split(",", attr_str) if attr_str else [] + + +def _parse_target_attributes_of_pool_name(attr_str, targets): + if not targets or attr_str is None: + return {} + + target_attributes = {} + for pool_values in attr_str: + pool_name, target_name, target_value = re.split(":", pool_values) + if pool_name not in target_attributes: + target_attributes[pool_name] = {} + + matched_targets = [target for target in targets if target_name == target.kind.name] + if matched_targets: + target_attributes[pool_name][matched_targets[0]] = target_value + else: + raise TVMCException( + "The workspace pool target specification " + "needs to contain a subset of the same TVM " + "targets as when specifying targets to use." + ) + return target_attributes + + +def _parse_attribute_of_pool_name(attr_str): + return dict(pool.split(":", maxsplit=1) for pool in attr_str) if attr_str else {} + + +def workspace_pools_recombobulate(parsed, targets, extra_target): + """Reconstructs the Workspace Pools args and returns a WorkspaceMemoryPool object""" + WORKSPACE_POOL_PARAMS = [ + "workspace_pools_size_hint_bytes", + "workspace_pools_targets", + "workspace_pools_clock_frequency_hz", + "workspace_pools_read_bandwidth_bytes_per_cycle", + "workspace_pools_write_bandwidth_bytes_per_cycle", + "workspace_pools_read_latency_cycles", + "workspace_pools_write_latency_cycles", + ] + WORKSPACE_POOL_TARGET_PARAMS = [ + "workspace_pools_target_burst_bytes", + ] + + # Load extra targets from CLI + additional_targets = [] + + for t in extra_target: + additional_targets.append(Target(t["raw"], host=targets[0].host or targets[0])) + + target = targets + additional_targets + if targets[0].host: + target.append(targets[0].host) + + workspace_pools = _split_pools_to_pool_names(parsed.workspace_pools) + if not workspace_pools: + return None + + parse_attribute_to_pool_name = { + workspace_pool_param: _parse_attribute_of_pool_name(getattr(parsed, workspace_pool_param)) + for workspace_pool_param in WORKSPACE_POOL_PARAMS + } + parse_target_burst_bytes_to_pool = { + workspace_pool_param: _parse_target_attributes_of_pool_name( + getattr(parsed, workspace_pool_param), targets + ) + for workspace_pool_param in WORKSPACE_POOL_TARGET_PARAMS + } + + return WorkspaceMemoryPools( + [ + WorkspacePoolInfo( + pool_name, + targets=_parse_target_string( + parse_attribute_to_pool_name["workspace_pools_targets"].get(pool_name), + target, + pool_name, + ), + pool_info_properties=PoolInfoProperties( + size_hint_bytes=int( + parse_attribute_to_pool_name["workspace_pools_size_hint_bytes"].get( + pool_name, -1 + ) + ), + clock_frequency_hz=int( + parse_attribute_to_pool_name["workspace_pools_clock_frequency_hz"].get( + pool_name, -1 + ) + ), + read_bandwidth_bytes_per_cycle=int( + parse_attribute_to_pool_name[ + "workspace_pools_read_bandwidth_bytes_per_cycle" + ].get(pool_name, -1) + ), + write_bandwidth_bytes_per_cycle=int( + parse_attribute_to_pool_name[ + "workspace_pools_write_bandwidth_bytes_per_cycle" + ].get(pool_name, -1) + ), + read_latency_cycles=int( + parse_attribute_to_pool_name["workspace_pools_read_latency_cycles"].get( + pool_name, 0 + ) + ), + write_latency_cycles=int( + parse_attribute_to_pool_name["workspace_pools_write_latency_cycles"].get( + pool_name, 0 + ) + ), + target_burst_bytes=_parse_target_burst( + parse_target_burst_bytes_to_pool["workspace_pools_target_burst_bytes"], + pool_name, + ), + ), + ) + for pool_name in workspace_pools + ] + ) diff --git a/python/tvm/ir/memory_pools.py b/python/tvm/ir/memory_pools.py index 0186a89f8413..553bb49e3c92 100644 --- a/python/tvm/ir/memory_pools.py +++ b/python/tvm/ir/memory_pools.py @@ -189,7 +189,7 @@ class WorkspaceMemoryPools(Object): def __init__( self, - pools: List[PoolInfo], + pools: List[WorkspacePoolInfo], ): self.__init_handle_by_constructor__( _ffi_api.WorkspaceMemoryPools, pools # type: ignore # pylint: disable=no-member diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index fd2f18aa9905..9a238fba3bf5 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -32,7 +32,9 @@ runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU) .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) - .set_attr("TIRToRuntime", TIRToRuntime); + .set_attr("TIRToRuntime", TIRToRuntime) + .add_attr_option>("mattr") + .add_attr_option("mcpu"); } // namespace cmsisnn } // namespace contrib diff --git a/tests/python/driver/tvmc/test_command_line.py b/tests/python/driver/tvmc/test_command_line.py index 0fddb7073f3f..af45f0bb7e00 100644 --- a/tests/python/driver/tvmc/test_command_line.py +++ b/tests/python/driver/tvmc/test_command_line.py @@ -21,6 +21,8 @@ from pytest_lazyfixture import lazy_fixture from unittest import mock + +import tvm from tvm.driver.tvmc.main import _main from tvm.driver.tvmc.model import TVMCException from tvm.driver.tvmc import compiler @@ -159,6 +161,26 @@ def test_tvmc_tune_file_check(capsys, invalid_input): assert captured.err == expected_err, on_assert_error +@mock.patch("tvm.relay.build", side_effect=tvm.relay.build) +@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None) +def test_tvmc_workspace_pools_check(mock_pkg, mock_relay, keras_simple, tmpdir_factory): + pytest.importorskip("tensorflow") + tmpdir = tmpdir_factory.mktemp("data") + + # Test model compilation + package_path = os.path.join(tmpdir, "keras-tvm.tar") + compile_str = ( + f"tvmc compile --target=llvm --workspace-pools=sram " + f"--workspace-pools-targets=sram:llvm " + f"--output={package_path} {keras_simple}" + ) + compile_args = compile_str.split(" ")[1:] + _main(compile_args) + assert os.path.exists(package_path) + assert mock_relay.call_count == 1 + assert mock_relay.call_args_list[0][1]["workspace_memory_pools"].pools[0].pool_name == "sram" + + @pytest.fixture def paddle_model(paddle_resnet50): # If we can't import "paddle" module, skip testing paddle as the input model. diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index e8e93a6c7514..27cd78d436c7 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -24,6 +24,8 @@ import pytest import tvm +from tvm.ir.memory_pools import WorkspacePoolInfo, WorkspaceMemoryPools +from tvm.target import Target import tvm.testing from tvm.relay.op.contrib.ethosn import ethosn_available from tvm.relay.backend import Runtime, Executor @@ -674,5 +676,25 @@ def test_compile_tflite_module_with_mod_name_and_ethosu( assert b"tvmgen_classify_ethos_u_main_" in content +@mock.patch("tvm.relay.build") +@mock.patch("tvm.driver.tvmc.load") +@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None) +def test_compile_check_workspace_pools(mock_pkg, mock_fe, mock_relay): + mock_fe.return_value = mock.MagicMock() + mock_relay.return_value = mock.MagicMock() + memory_pools = WorkspaceMemoryPools( + [WorkspacePoolInfo(pool_name="sram", targets=[Target("llvm")])] + ) + tvmc_model = tvmc.load("no_file_needed") + tvmc.compile( + tvmc_model, + target="llvm,c", + workspace_pools=memory_pools, + ) + + assert mock_relay.call_count == 1 + assert mock_relay.call_args_list[0][1]["workspace_memory_pools"] == memory_pools + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/driver/tvmc/test_workspace_pools.py b/tests/python/driver/tvmc/test_workspace_pools.py new file mode 100644 index 000000000000..386181aaf20b --- /dev/null +++ b/tests/python/driver/tvmc/test_workspace_pools.py @@ -0,0 +1,404 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import argparse + +from tvm.driver.tvmc.workspace_pools import ( + generate_workspace_pools_args, + workspace_pools_recombobulate, +) +from tvm.target import Target +from tvm.driver.tvmc import TVMCException + + +def test_workspace_pools_argparse(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, unparsed = parser.parse_known_args( + [ + "--workspace-pools=sram,flash", + "--workspace-pools-targets=sram:c,llvm", + "--workspace-pools-targets=flash:c", + "--workspace-pools-size-hint-bytes=sram:400", + "--workspace-pools-size-hint-bytes=sram:500", + "--workspace-pools-clock-frequency-hz=sram:500", + "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:200", + "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:100", + "--workspace-pools-read-latency-cycles=sram:50", + "--workspace-pools-read-latency-cycles=flash:30", + "--workspace-pools-write-latency-cycles=sram:9001", + "--workspace-pools-target-burst-bytes=sram:c:2", + "--workspace-pools-is-internal=sram:0", + ] + ) + + assert parsed.workspace_pools == "sram,flash" + assert parsed.workspace_pools_targets == ["sram:c,llvm", "flash:c"] + assert parsed.workspace_pools_size_hint_bytes == ["sram:400", "sram:500"] + assert parsed.workspace_pools_clock_frequency_hz == ["sram:500"] + assert parsed.workspace_pools_read_bandwidth_bytes_per_cycle == ["sram:200"] + assert parsed.workspace_pools_write_bandwidth_bytes_per_cycle == ["sram:100"] + assert parsed.workspace_pools_read_latency_cycles == ["sram:50", "flash:30"] + assert parsed.workspace_pools_write_latency_cycles == ["sram:9001"] + assert parsed.workspace_pools_target_burst_bytes == ["sram:c:2"] + + assert unparsed == ["--workspace-pools-is-internal=sram:0"] + + +def test_workspace_pools_recombobulate_empty(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args([]) + + targets = [Target("llvm")] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + assert memory_pools is None + + +def test_workspace_pools_recombobulate(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + "--workspace-pools-targets=sram:llvm", + "--workspace-pools-size-hint-bytes=sram:400", + "--workspace-pools-clock-frequency-hz=sram:500", + ] + ) + + targets = [Target("llvm")] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + assert len(memory_pools.pools) == 1 + assert memory_pools.pools[0].pool_name == "sram" + assert memory_pools.pools[0].size_hint_bytes == 400 + assert memory_pools.pools[0].clock_frequency_hz == 500 + + +def test_workspace_pools_defaults(): + parser = argparse.ArgumentParser() + targets = [Target("llvm")] + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + "--workspace-pools-targets=sram:llvm", + ] + ) + + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + assert len(memory_pools.pools) == 1 + assert memory_pools.pools[0].pool_name == "sram" + assert memory_pools.pools[0].size_hint_bytes == -1 + assert memory_pools.pools[0].clock_frequency_hz == -1 + assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == -1 + assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == -1 + assert memory_pools.pools[0].read_latency_cycles == 0 + assert memory_pools.pools[0].write_latency_cycles == 0 + assert len(memory_pools.pools[0].target_burst_bytes) == 0 + + +def test_workspace_pools_recombobulate_multi_fields(): + parser = argparse.ArgumentParser() + targets = [Target("c")] + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + "--workspace-pools-targets=sram:c", + "--workspace-pools-size-hint-bytes=sram:400", + "--workspace-pools-clock-frequency-hz=sram:500", + "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:200", + "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:100", + "--workspace-pools-read-latency-cycles=sram:50", + "--workspace-pools-write-latency-cycles=sram:9001", + "--workspace-pools-target-burst-bytes=sram:c:2", + ] + ) + + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + assert len(memory_pools.pools) == 1 + assert memory_pools.pools[0].pool_name == "sram" + assert memory_pools.pools[0].size_hint_bytes == 400 + assert memory_pools.pools[0].clock_frequency_hz == 500 + assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 200 + assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 100 + assert memory_pools.pools[0].read_latency_cycles == 50 + assert memory_pools.pools[0].write_latency_cycles == 9001 + assert len(memory_pools.pools[0].target_burst_bytes) == 1 + assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 2 + + +def test_workspace_pools_recombobulate_multi_fields_variant(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=flash", + "--workspace-pools-targets=flash:c", + "--workspace-pools-size-hint-bytes=flash:2048", + "--workspace-pools-clock-frequency-hz=flash:2000000", + "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4", + "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1", + "--workspace-pools-read-latency-cycles=flash:2000", + "--workspace-pools-write-latency-cycles=flash:1000", + "--workspace-pools-target-burst-bytes=flash:c:4", + ] + ) + + targets = [Target("c")] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + assert len(memory_pools.pools) == 1 + assert memory_pools.pools[0].pool_name == "flash" + assert memory_pools.pools[0].size_hint_bytes == 2048 + assert memory_pools.pools[0].clock_frequency_hz == 2000000 + assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 4 + assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 1 + assert memory_pools.pools[0].read_latency_cycles == 2000 + assert memory_pools.pools[0].write_latency_cycles == 1000 + assert len(memory_pools.pools[0].target_burst_bytes) == 1 + assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 4 + + +def test_workspace_pools_recombobulate_multi_fields_multi_pools(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram,flash", + "--workspace-pools-targets=sram:c", + "--workspace-pools-targets=flash:c", + "--workspace-pools-size-hint-bytes=sram:1024", + "--workspace-pools-size-hint-bytes=flash:2048", + "--workspace-pools-clock-frequency-hz=sram:4000000", + "--workspace-pools-clock-frequency-hz=flash:2000000", + "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:8", + "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4", + "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:4", + "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1", + "--workspace-pools-read-latency-cycles=sram:250", + "--workspace-pools-read-latency-cycles=flash:2000", + "--workspace-pools-write-latency-cycles=sram:500", + "--workspace-pools-write-latency-cycles=flash:1000", + "--workspace-pools-target-burst-bytes=sram:c:8", + "--workspace-pools-target-burst-bytes=flash:c:4", + ] + ) + + targets = [Target("c")] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + assert len(memory_pools.pools) == 2 + + assert memory_pools.pools[0].pool_name == "sram" + assert memory_pools.pools[0].size_hint_bytes == 1024 + assert memory_pools.pools[0].clock_frequency_hz == 4000000 + assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 8 + assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 4 + assert memory_pools.pools[0].read_latency_cycles == 250 + assert memory_pools.pools[0].write_latency_cycles == 500 + assert len(memory_pools.pools[0].target_burst_bytes) == 1 + assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 8 + + assert memory_pools.pools[1].pool_name == "flash" + assert memory_pools.pools[1].size_hint_bytes == 2048 + assert memory_pools.pools[1].clock_frequency_hz == 2000000 + assert memory_pools.pools[1].read_bandwidth_bytes_per_cycle == 4 + assert memory_pools.pools[1].write_bandwidth_bytes_per_cycle == 1 + assert memory_pools.pools[1].read_latency_cycles == 2000 + assert memory_pools.pools[1].write_latency_cycles == 1000 + assert len(memory_pools.pools[1].target_burst_bytes) == 1 + assert memory_pools.pools[1].target_burst_bytes[targets[0]] == 4 + + +def test_workspace_pools_recombobulate_multi_fields_ordering(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram,flash", + "--workspace-pools-targets=flash:c", + "--workspace-pools-targets=sram:c", + "--workspace-pools-size-hint-bytes=flash:2048", + "--workspace-pools-size-hint-bytes=sram:1024", + "--workspace-pools-clock-frequency-hz=sram:4000000", + "--workspace-pools-clock-frequency-hz=flash:2000000", + "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:8", + "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4", + "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:4", + "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1", + "--workspace-pools-read-latency-cycles=sram:250", + "--workspace-pools-read-latency-cycles=flash:2000", + "--workspace-pools-write-latency-cycles=flash:1000", + "--workspace-pools-write-latency-cycles=sram:500", + "--workspace-pools-target-burst-bytes=sram:c:8", + "--workspace-pools-target-burst-bytes=flash:c:4", + ] + ) + + targets = [Target("c")] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + assert len(memory_pools.pools) == 2 + + assert memory_pools.pools[0].pool_name == "sram" + assert memory_pools.pools[0].size_hint_bytes == 1024 + assert memory_pools.pools[0].write_latency_cycles == 500 + + assert memory_pools.pools[1].pool_name == "flash" + assert memory_pools.pools[1].size_hint_bytes == 2048 + assert memory_pools.pools[1].write_latency_cycles == 1000 + + +def test_workspace_pools_recombobulate_multi_target(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + "--workspace-pools-targets=sram:c,llvm", + "--workspace-pools-target-burst-bytes=sram:c:8", + "--workspace-pools-target-burst-bytes=sram:llvm:4", + ] + ) + + c_target = Target("c") + llvm_target = Target("llvm") + extra_targets = [] + + targets = [c_target, llvm_target] + memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets) + + assert len(memory_pools.pools) == 1 + + assert len(memory_pools.pools[0].target_burst_bytes) == 2 + assert memory_pools.pools[0].target_burst_bytes[c_target] == 8 + assert memory_pools.pools[0].target_burst_bytes[llvm_target] == 4 + + +def test_workspace_pools_recombobulate_no_target_burst_bytes(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + "--workspace-pools-targets=sram:c", + "--workspace-pools-target-burst-bytes=sram:c:8", + ] + ) + + c_target = Target("c") + targets = [c_target] + + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + + assert len(memory_pools.pools) == 1 + assert len(memory_pools.pools[0].target_burst_bytes) == 1 + assert memory_pools.pools[0].target_burst_bytes[c_target] == 8 + + +def test_workspace_pools_recombobulate_missing_target(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + ] + ) + + c_target = Target("c") + with pytest.raises(TVMCException): + workspace_pools_recombobulate(parsed, [c_target], _) + + +def test_workspace_pools_recombobulate_multi_target_multi_pool(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + "--workspace-pools-targets=sram:c,llvm", + "--workspace-pools-target-burst-bytes=sram:c:8", + "--workspace-pools-target-burst-bytes=sram:llvm:4", + ] + ) + + c_target = Target("c") + llvm_target = Target("llvm") + + targets = [c_target, llvm_target] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + + assert len(memory_pools.pools) == 1 + + assert len(memory_pools.pools[0].target_burst_bytes) == 2 + assert memory_pools.pools[0].target_burst_bytes[llvm_target] == 4 + assert memory_pools.pools[0].target_burst_bytes[c_target] == 8 + + +def test_workspace_pools_recombobulate_parameter_overrides(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram", + "--workspace-pools-targets=sram:c", + "--workspace-pools-size-hint-bytes=sram:800", + "--workspace-pools-size-hint-bytes=sram:400", + "--workspace-pools-clock-frequency-hz=sram:4000000", + "--workspace-pools-clock-frequency-hz=sram:3600000", + ] + ) + + c_target = Target("c") + + targets = [c_target] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + + assert len(memory_pools.pools) == 1 + + assert memory_pools.pools[0].size_hint_bytes == 400 + assert memory_pools.pools[0].clock_frequency_hz == 3600000 + + +def test_workspace_pools_recombobulate_single_pool_overrides(): + parser = argparse.ArgumentParser() + generate_workspace_pools_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--workspace-pools=sram,flash", + "--workspace-pools-targets=sram:c", + "--workspace-pools-targets=flash:c", + "--workspace-pools-targets=sram:c,llvm", # Override on one pool + "--workspace-pools-size-hint-bytes=sram:800", + "--workspace-pools-size-hint-bytes=flash:1200", + "--workspace-pools-size-hint-bytes=sram:400", # Override on one pool + ] + ) + + c_target = Target("c") + llvm_target = Target("llvm") + + targets = [c_target, llvm_target] + memory_pools = workspace_pools_recombobulate(parsed, targets, _) + + assert len(memory_pools.pools) == 2 + + assert memory_pools.pools[0].size_hint_bytes == 400 + assert memory_pools.pools[1].size_hint_bytes == 1200 + + assert len(memory_pools.pools[0].targets) == 2 + assert len(memory_pools.pools[1].targets) == 1