Skip to content

Commit

Permalink
Lower cache_read and cache_write to Hexagon DMA via tensorize (apache…
Browse files Browse the repository at this point in the history
…#10365)

* Lower cache_read and cache_write to Hexagon DMA via tensorize

* rework test to be compatible with launcher

* remove cpu device api mem_copy implementation and test
  • Loading branch information
adstraw authored and pfk-beta committed Apr 11, 2022
1 parent ca99420 commit 924f438
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 1 deletion.
7 changes: 7 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,13 @@ TVM_DLL const Op& texture2d_store();
*/
TVM_DLL const Op& texture2d_load();

/*!
* \brief Copy a 1-d (contiguous) region of memory from source to destination.
*
* Same semantics as memcpy(destination, source, size): the call takes three
* arguments in that order, with the destination first and the size last
* (the size is a byte count, as with memcpy).
*
* The op itself is device agnostic. It allows for device specific
* implementations, e.g. direct memory access (DMA); during builtin lowering it
* is rewritten into a packed call to "device_api.<device name>.mem_copy".
*/
TVM_DLL const Op& mem_copy();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/hexagon/hexagon/hexagon_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace tvm {
namespace runtime {
namespace hexagon {

int hexagon_user_dma_1d_sync(void* src, void* dst, uint32_t length);
int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length);

struct Allocation {
Allocation(size_t allocation_nbytes, size_t alignment)
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ namespace tvm {
namespace runtime {
namespace hexagon {

// Forward declaration of the 1-d user DMA copy routine that the
// "device_api.hexagon.mem_copy" packed function in this file calls.
// Argument order follows memcpy: destination first, then source, then length
// (presumably a byte count - confirm against the definition).
int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length);

HexagonDeviceAPIv2* HexagonDeviceAPIv2::Global() {
static auto* inst = new HexagonDeviceAPIv2();
return inst;
Expand Down Expand Up @@ -149,6 +151,16 @@ void HexagonDeviceAPIv2::CopyDataFromTo(const void* from, size_t from_offset, vo
memcpy(static_cast<char*>(to) + to_offset, static_cast<const char*>(from) + from_offset, size);
}

// Packed-function entry point that the lowered tir.mem_copy builtin calls on Hexagon.
//
// Arguments (memcpy order):
//   args[0] - destination pointer
//   args[1] - source pointer
//   args[2] - number of bytes to copy
//
// The copy is carried out by hexagon_user_dma_1d_sync. On success the packed
// function returns 0; a malformed call or a failed DMA transfer raises a fatal
// check failure instead of being silently ignored.
TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
  CHECK_EQ(args.num_args, 3) << "device_api.hexagon.mem_copy expects (dst, src, size), but got "
                             << args.num_args << " arguments";

  void* dst = args[0];
  void* src = args[1];
  int size = args[2];

  // The DMA routine takes an unsigned 32-bit length; a negative size would
  // otherwise wrap around to a huge transfer.
  CHECK_GE(size, 0) << "device_api.hexagon.mem_copy: negative copy size " << size;

  // NOTE(review): a zero status is treated as success, matching the value this
  // wrapper itself reports - confirm against the DMA routine's definition.
  int error_code = hexagon_user_dma_1d_sync(dst, src, static_cast<uint32_t>(size));
  CHECK_EQ(error_code, 0) << "device_api.hexagon.mem_copy: hexagon_user_dma_1d_sync failed "
                          << "with status " << error_code;

  *rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = HexagonDeviceAPIv2::Global();
*rv = static_cast<void*>(ptr);
Expand Down
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
.set_attr<TVectorizable>("TVectorizable", true)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

// Register the mem_copy builtin op. Its call effect is tagged kOpaque, the same
// effect kind given to texture2d_load above.
// NOTE(review): presumably kOpaque keeps passes from treating the copy as a
// side-effect-free expression - confirm against the CallEffectKind docs.
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

} // namespace builtin
} // namespace tir
} // namespace tvm
16 changes: 16 additions & 0 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,26 @@ class BuiltinLower : public StmtExprMutator {
return MakeArray(op);
} else if (op->op.same_as(builtin::tvm_context_id())) {
return make_zero(op->dtype);
} else if (op->op.same_as(builtin::mem_copy())) {
return MakeMemCopy(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}

/*!
 * \brief Lower a builtin::mem_copy call into a packed call to the device API.
 *
 * The three call arguments (dst, src, size) are forwarded unchanged to the
 * packed function named "device_api.<device name>.mem_copy", where the device
 * name is derived from the device type recorded in device_type_.
 *
 * \param op The mem_copy call node; must carry exactly three arguments.
 * \return The result of visiting the generated tvm_call_packed expression.
 */
PrimExpr MakeMemCopy(const CallNode* op) {
  ICHECK_EQ(op->args.size(), 3U)
      << "mem_copy expects 3 arguments (dst, src, size), but got " << op->args.size();
  PrimExpr dst = op->args[0];
  PrimExpr src = op->args[1];
  PrimExpr size = op->args[2];

  // as<IntImmNode>() yields nullptr when device_type_ is undefined or is not an
  // integer constant; fail with a diagnostic instead of dereferencing it.
  const auto* device_type = device_type_.as<IntImmNode>();
  ICHECK(device_type != nullptr)
      << "mem_copy lowering requires a constant device type in the current IR";

  std::string fdevapi_prefix =
      "device_api." + std::string(runtime::DeviceName(device_type->value));

  Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
                          {StringImm(fdevapi_prefix + ".mem_copy"), dst, src, size});
  // Re-visit so the freshly built tvm_call_packed is itself lowered.
  return VisitExpr(call_packed);
}

// call shape
PrimExpr MakeShape(const CallNode* op) {
// if args.size() == 0, it represents a scalar shape ()
Expand Down
135 changes: 135 additions & 0 deletions tests/python/contrib/test_hexagon/test_cache_read_write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest
import numpy as np

import tvm.testing
from tvm import te
from tvm.contrib import utils
from tvm.contrib.hexagon.build import HexagonLauncher
import tvm.contrib.hexagon.hexagon as hexagon

from .conftest import requires_hexagon_toolchain


def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
    """Declare a tensor intrinsic that copies a 1-d tensor via ``tir.mem_copy``.

    Parameters
    ----------
    shape : tuple
        Shape of the tensor to copy; must have exactly one dimension.
    dtype : str
        Element type of both the source and the destination.
    dst_scope : str
        Storage scope of the destination buffer.
    src_scope : str
        Storage scope of the source buffer.

    Returns
    -------
    The tensor intrinsic produced by ``te.decl_tensor_intrin``, suitable for
    ``tensorize``.
    """
    assert len(shape) == 1

    # Element-wise identity compute: this is the pattern the intrinsic replaces.
    src = te.placeholder(shape=shape, dtype=dtype, name="src")
    dst = te.compute(shape, lambda i: src[i], name="dst")

    # tir.mem_copy takes a byte count: number of elements times element width.
    size = shape[0] * np.dtype(dtype).itemsize

    def make_buffer(scope):
        return tvm.tir.decl_buffer(shape, dtype, scope=scope, offset_factor=1)

    src_buffer = make_buffer(src_scope)
    dst_buffer = make_buffer(dst_scope)

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        # Destination pointer first, then source, then size (memcpy order).
        write_ptr = outs[0].access_ptr("w")
        read_ptr = ins[0].access_ptr("r")
        copy_call = tvm.tir.call_intrin("handle", "tir.mem_copy", write_ptr, read_ptr, size)
        builder.emit(copy_call)
        return builder.get()

    return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buffer, dst: dst_buffer})


@requires_hexagon_toolchain
def test_cache_read_write(android_serial_number, tvm_tracker_host, tvm_tracker_port):
    """Check that cache_read/cache_write stages tensorized with tir.mem_copy build
    for Hexagon and, when a device is available, compute ``z = x + y`` correctly.

    The schedule stages ``x`` and ``y`` into "global.vtcm" in chunks of ``factor``
    elements and writes ``z`` back to "global" in a single copy. The hardware part
    of the test is skipped when no Android serial number is configured.
    """
    size = 128
    outer_shape = (size,)
    factor = 16
    inner_shape = (factor,)
    dtype = "int8"

    x = te.placeholder(shape=outer_shape, dtype=dtype, name="x")
    y = te.placeholder(shape=outer_shape, dtype=dtype, name="y")
    z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
    s = te.create_schedule(z.op)

    x_global = s.cache_read(x, "global.vtcm", [z])
    y_global = s.cache_read(y, "global.vtcm", [z])
    z_global = s.cache_write(z, "global.vtcm")

    zouter, zinner = s[z_global].split(z_global.op.axis[0], factor=factor)

    s[x_global].compute_at(s[z_global], zouter)
    s[y_global].compute_at(s[z_global], zouter)

    # Reads are copied one inner chunk at a time into VTCM.
    mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")

    (cache_read_x,) = s[x_global].op.axis
    s[x_global].tensorize(cache_read_x, mem_copy_read)

    (cache_read_y,) = s[y_global].op.axis
    s[y_global].tensorize(cache_read_y, mem_copy_read)

    # The result is copied back from VTCM in one full-size transfer.
    mem_copy_write = intrin_mem_copy(outer_shape, dtype, "global", "global.vtcm")

    (cache_write_z,) = s[z].op.axis
    s[z].tensorize(cache_write_z, mem_copy_write)

    print(tvm.lower(s, [x, y, z]))

    target_hexagon = tvm.target.hexagon("v68", link_params=True)
    func = tvm.build(
        s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
    )
    temp = utils.tempdir()
    dso_binary = "test_binary.so"
    dso_binary_path = temp.relpath(dso_binary)
    func.save(dso_binary_path)

    if not android_serial_number:
        pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")

    # np.random.uniform() yields values in [0, 1), which all become 0 when cast to
    # int8 and would make the check pass even if nothing were copied. Draw real
    # int8 values instead, bounded so that x + y cannot overflow int8.
    x_host = np.random.randint(-64, 64, size=size, dtype=x.dtype)
    y_host = np.random.randint(-64, 64, size=size, dtype=y.dtype)
    z_host = np.zeros(size, dtype=z.dtype)

    launcher = HexagonLauncher(serial_number=android_serial_number)
    launcher.android_run_rpc(rpc_tracker_host=tvm_tracker_host, rpc_tracker_port=tvm_tracker_port)
    launcher.hexagon_setup()
    remote_kw = {
        "host": tvm_tracker_host,
        "port": tvm_tracker_port,
        "priority": 0,
        "timeout": 60,
    }
    launcher.hexagon_session_setup(remote_kw)
    launcher.upload(dso_binary_path, dso_binary)

    with launcher.session as sess:
        mod = launcher.get_module(dso_binary)
        xt = tvm.nd.array(x_host, device=sess.device)
        yt = tvm.nd.array(y_host, device=sess.device)
        zt = tvm.nd.array(z_host, device=sess.device)
        mod["dmacpy"](xt, yt, zt)
        # Read the result back while the remote session is still open.
        z_out = zt.numpy()
    launcher.close()

    ref = x_host + y_host
    np.testing.assert_equal(z_out, ref)

0 comments on commit 924f438

Please sign in to comment.