Skip to content

Commit

Permalink
addressed upstream comments-3
Browse files Browse the repository at this point in the history
  • Loading branch information
d-smirnov committed Jun 17, 2022
1 parent db6d324 commit 243b1f9
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 28 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/memory_pools.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ struct ConstantInfoNode : public Object {
hash_reduce(data);
}

static constexpr const char* _type_key = "tir.usmp.ConstantInfo";
static constexpr const char* _type_key = "ir.ConstantInfo";
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantInfoNode, Object);
Expand Down
34 changes: 11 additions & 23 deletions python/tvm/ir/memory_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,6 @@ class PoolInfo(Object):
"""

# The string parameter to indicate read and write access to a pool
# This needs to be kept in sync with kTargetPoolReadWriteAccess in
# include/tvm/ir/memory_pools.h
READ_WRITE_ACCESS = "rw"
# The string parameter to indicate read only access to a pool
# This needs to be kept in sync with kTargetPoolReadOnlyAccess in
# include/tvm/ir/memory_pools.h
READ_ONLY_ACCESS = "ro"

def __init__(
self,
pool_name: str,
Expand All @@ -98,18 +89,15 @@ def __init__(
if not target_burst_bytes:
target_burst_bytes = dict()

self.__init_handle_by_constructor__(
_ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member
pool_name,
target_access,
size_hint_bytes,
clock_frequency_hz,
read_bandwidth_bytes_per_cycle,
write_bandwidth_bytes_per_cycle,
read_latency_cycles,
write_latency_cycles,
target_burst_bytes,
)
self.pool_name = pool_name
self.target_access = target_access
self.size_hint_bytes = size_hint_bytes
self.clock_frequency_hz = clock_frequency_hz
self.read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle
self.write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle
self.read_latency_cycles = read_latency_cycles
self.write_latency_cycles = write_latency_cycles
self.target_burst_bytes = target_burst_bytes


@register_object("ir.PoolInfoProperties")
Expand Down Expand Up @@ -196,13 +184,13 @@ class WorkspacePoolInfo(PoolInfo):
The properties of the pool.
"""

# pylint: disable=W0231
def __init__(
self,
pool_name: str,
targets,
pool_info_properties=None,
):
super().__init__(pool_name, targets)
if pool_info_properties is None:
pool_info_properties = PoolInfoProperties()

Expand Down Expand Up @@ -231,14 +219,14 @@ class ConstantPoolInfo(PoolInfo):
The properties of the pool.
"""

# pylint: disable=W0231
def __init__(
self,
pool_name: str,
targets, # list[Target]
constant_info_arr=None, # list[ConstantInfo]
pool_info_properties=None,
):
super().__init__(pool_name, targets)
if constant_info_arr is None:
constant_info_arr = []
if pool_info_properties is None:
Expand Down
2 changes: 1 addition & 1 deletion src/ir/memory_pools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ ConstantInfo::ConstantInfo(String name_hint, Integer byte_offset, runtime::NDArr
}

TVM_REGISTER_NODE_TYPE(ConstantInfoNode);
TVM_REGISTER_GLOBAL("tir.usmp.ConstantInfo")
TVM_REGISTER_GLOBAL("ir.ConstantInfo")
.set_body_typed([](String name_hint, Integer byte_offset, runtime::NDArray data) {
return ConstantInfo(name_hint, byte_offset, data);
});
Expand Down
8 changes: 5 additions & 3 deletions tests/python/relay/aot/test_crt_aot_usmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tvm.relay import transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend import Executor, Runtime
from tvm import WorkspaceMemoryPools, WorkspacePoolInfo, ConstantPoolInfo, PoolInfoProperties
from tvm import WorkspaceMemoryPools, WorkspacePoolInfo, PoolInfoProperties
from tvm.micro import model_library_format as mlf
from tvm.micro.testing.aot_test_utils import parametrize_aot_options
from tvm.testing.aot import (
Expand All @@ -51,14 +51,17 @@ def _check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module):
# U1 test case
@parametrize_aot_options
def test_synthetic(interface_api, use_unpacked_api, test_runner):
"""
Simple U1 usecase test
"""
mod, params = tvm.relay.testing.synthetic.get_workload()
main_func = mod["main"]
shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}
type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params}

input_data = np.ones(shape_dict["data"]).astype(type_dict["data"])
params = {}
for name, shape in shape_dict.items():
for name, _ in shape_dict.items():
if name != "data":
params[name] = np.ones(shape_dict[name]).astype(type_dict[name])

Expand All @@ -73,7 +76,6 @@ def test_synthetic(interface_api, use_unpacked_api, test_runner):
},
)

pass_config = {"tir.usmp.enable": True}
test_runner = AOTTestRunner(
makefile=test_runner.makefile,
prologue=test_runner.prologue,
Expand Down

0 comments on commit 243b1f9

Please sign in to comment.