Skip to content

Commit

Permalink
update vta code
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jul 14, 2021
1 parent fe626f3 commit 2128bd4
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions vta/python/vta/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,21 +495,21 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored...
env = get_env()
_ = pad_value
if dst.scope == "global":
if dst.scope() == "global":
# Store
if pad_before or pad_after:
raise RuntimeError("Do not support copy into DRAM with pad")
if src.scope == env.acc_scope:
if src.scope() == env.acc_scope:
elem_width = env.OUT_WIDTH
elem_bytes = env.OUT_ELEM_BYTES
mem_type = env.dev.MEM_ID_OUT
data_type = "int%d" % env.OUT_WIDTH
task_qid = env.dev.QID_STORE_OUT
else:
raise RuntimeError("Do not support copy %s->dram" % (src.scope))
raise RuntimeError("Do not support copy %s->dram" % (src.scope()))
_check_compact(src)
x_size, y_size, x_stride, offset = _get_2d_pattern(
dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True
dst, elem_width, elem_bytes, data_type, src.scope(), allow_fold=True
)
irb = tvm.tir.ir_builder.create()
irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
Expand All @@ -528,27 +528,27 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
)
)
return irb.get()
elif src.scope == "global":
if dst.scope == env.acc_scope:
elif src.scope() == "global":
if dst.scope() == env.acc_scope:
elem_width = env.ACC_WIDTH
elem_bytes = env.ACC_ELEM_BYTES
mem_type = env.dev.MEM_ID_ACC
data_type = "int%d" % env.ACC_WIDTH
task_qid = env.dev.QID_LOAD_OUT
elif dst.scope == env.inp_scope:
elif dst.scope() == env.inp_scope:
elem_width = env.INP_WIDTH
elem_bytes = env.INP_ELEM_BYTES
mem_type = env.dev.MEM_ID_INP
data_type = "int%d" % env.INP_WIDTH
task_qid = env.dev.QID_LOAD_INP
elif dst.scope == env.wgt_scope:
elif dst.scope() == env.wgt_scope:
elem_width = env.WGT_WIDTH
elem_bytes = env.WGT_ELEM_BYTES
mem_type = env.dev.MEM_ID_WGT
data_type = "int%d" % env.WGT_WIDTH
task_qid = env.dev.QID_LOAD_WGT
else:
raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
raise RuntimeError("Do not support copy dram->%s" % (dst.scope()))
# collect pad statistics
if pad_before:
assert pad_after
Expand Down Expand Up @@ -586,7 +586,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):

_check_compact(dst)
x_size, y_size, x_stride, offset = _get_2d_pattern(
src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold
src, elem_width, elem_bytes, data_type, dst.scope(), allow_fold=allow_fold
)

if data_type != src.dtype:
Expand Down Expand Up @@ -617,7 +617,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
return irb.get()

else:
raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
raise RuntimeError("Do not support copy %s->%s" % (src.scope(), dst.scope()))

return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)

Expand Down

0 comments on commit 2128bd4

Please sign in to comment.