Skip to content

Commit

Permalink
[AutoTVM] [TOPI] Support AutoTVM for int4 tensorcore (#7831)
Browse files Browse the repository at this point in the history
* initial

* int4 asnumpy

* remove

* random test

* format

* random

* remove unused import

* change dist range

* add fuse_pack in

* random engine

* reformat

* remove import

* add cuda context

* refactor code
  • Loading branch information
hypercubestart authored May 1, 2021
1 parent ae1f3d4 commit dc1f189
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 38 deletions.
3 changes: 1 addition & 2 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from random import getrandbits
from collections import namedtuple
import tempfile
import numpy as np

import tvm._ffi
import tvm.ir.transform
Expand Down Expand Up @@ -583,7 +582,7 @@ def run_through_rpc(
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
args = [nd.array(np.zeros(x[0], dtype=x[1]), device=dev) for x in build_result.arg_info]
args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info]
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,27 @@ def asnumpy(self):
"""
t = DataType(self.dtype)
shape, dtype = self.shape, self.dtype
old_dtype = dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
if dtype == "int4":
dtype = "int8"
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
if old_dtype == "int4":
length = np_arr.size
np_arr_ret = np.empty((length,), dtype="int8")
np_arr = np_arr.reshape((length,))
old_index = np.bitwise_and(np_arr, 0x0F)
even_index = np.bitwise_and(np_arr >> 4, 0x0F)
np_arr_ret[1::2] = old_index[0 : length // 2]
np_arr_ret[0::2] = even_index[0 : length // 2]
return np_arr_ret.reshape(shape)
return np_arr

def copyto(self, target):
Expand Down
37 changes: 6 additions & 31 deletions python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp

def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
"""Schedule tensorcore template"""
packed_data, packed_kernel = s[Conv].op.input_tensors
pad_data, packed_kernel = s[Conv].op.input_tensors
ic, kh, kw, ii = s[Conv].op.reduce_axis
pad_data = s[packed_data].op.input_tensors[0]
packed_data = s[pad_data].op.input_tensors[0]

block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
Expand All @@ -196,7 +196,7 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
thread_z = te.thread_axis("threadIdx.z")

# Designate the memory hierarchy
AS = s.cache_read(packed_data, "shared", [Conv])
AS = s.cache_read(pad_data, "shared", [Conv])
WS = s.cache_read(packed_kernel, "shared", [Conv])
AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
Expand Down Expand Up @@ -241,7 +241,6 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16])
cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16])
cfg.define_knob("chunk", [1, 2, 4, 8])
cfg.define_knob("fuse_pack", [0, 1])
cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32])
cfg.define_knob("vector_ws", [1, 8])
cfg.define_knob("vector_as", [1, 8, 16])
Expand All @@ -254,13 +253,8 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
vector_as = cfg["vector_as"].val
vector_ws = cfg["vector_ws"].val
split_block_k_nums = cfg["split_block_k_nums"].val
fuse_pack = cfg["fuse_pack"].val

if not fuse_pack:
s[packed_data].compute_inline()
else:
with Target("cuda"):
schedule_injective_from_existing(s, packed_data)
s[packed_data].compute_inline()

if data_dtype in ["int4", "uint4"]:
wmma_m = wmma_n = 8
Expand Down Expand Up @@ -324,24 +318,13 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
cfg["reorder_inner"].apply(s, ConvF, [ko, kh])
cfg["reorder_inner"].apply(s, ConvF, [ki, kw])

cfg.define_knob("compute_at_AS", [0, 1, 2, 3])
cfg.define_knob("compute_at_WS", [0, 1, 2, 3])
compute_at_AS = cfg["compute_at_AS"].val
compute_at_WS = cfg["compute_at_WS"].val

# Move intermediate computation into each output compute tile
s[AF].compute_at(s[ConvF], kw)
s[WF].compute_at(s[ConvF], kw)

# Schedule for A's share memory
if compute_at_AS == 0:
s[AS].compute_at(s[ConvF], ki)
elif compute_at_AS == 1:
s[AS].compute_at(s[ConvF], kw)
elif compute_at_AS == 2:
s[AS].compute_at(s[ConvF], ko)
else:
s[AS].compute_at(s[ConvF], kh)
s[AS].compute_at(s[ConvF], ko)

_, _, n, _, nn, ii = AS.op.axis
tx, xo = s[AS].split(n, nparts=block_row_warps)
ty, _ = s[AS].split(xo, nparts=block_col_warps)
Expand All @@ -354,14 +337,6 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
s[AS].vectorize(_t)

# Schedule for W's share memory
if compute_at_WS == 0:
s[WS].compute_at(s[ConvF], ki)
elif compute_at_WS == 1:
s[WS].compute_at(s[ConvF], kw)
elif compute_at_WS == 2:
s[WS].compute_at(s[ConvF], ko)
else:
s[WS].compute_at(s[ConvF], kh)
s[WS].compute_at(s[ConvF], kw)
kh, kw, ic, o, ii, oo = WS.op.axis
tx, xo = s[WS].split(o, nparts=block_row_warps)
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/contrib/random/mt_random_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ class RandomEngine {
// Use float representation could make us work well on float / int type too.
if (tensor->dtype.bits == 1) {
std::generate_n(static_cast<bool*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
} else if (tensor->dtype.bits == 4) {
// For uint4/int4 we pack two values into a single byte.
// Thus, to ensure both values are non-zero, we use a distribution of 17 - 30.
std::uniform_real_distribution<> packed_dist(17.0, 30.0);
std::generate_n(reinterpret_cast<uint8_t*>(tensor->data), size,
[&]() { return packed_dist(rnd_engine_); });
} else if (tensor->dtype.bits == 8) {
std::generate_n(static_cast<uint8_t*>(tensor->data), size,
[&]() { return dist(rnd_engine_); });
Expand Down
14 changes: 14 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,20 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
return;
}

if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4 && op->lanes == 8) {
// make_int4x8
const int64_t* p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xF;
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
return;
}

std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
Expand Down
7 changes: 3 additions & 4 deletions tests/python/contrib/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ def test_local(dev, dtype):
if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
print("skip because extern function is not available")
return
np_ones = np.ones((512, 512), dtype=dtype)
value = tvm.nd.empty(np_ones.shape, np_ones.dtype, dev)
value = tvm.nd.empty((512, 512), dtype, dev)
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
random_fill(value)

Expand All @@ -119,10 +118,9 @@ def test_rpc(dtype):
return
if not tvm.testing.device_enabled("rpc") or not tvm.runtime.enabled("llvm"):
return
np_ones = np.ones((512, 512), dtype=dtype)
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
value = tvm.nd.empty(np_ones.shape, np_ones.dtype, remote.cpu())
value = tvm.nd.empty((512, 512), dtype, remote.cpu())
random_fill = remote.get_function("tvm.contrib.random.random_fill")
random_fill(value)

Expand All @@ -134,6 +132,7 @@ def test_rpc(dtype):

for dtype in [
"bool",
"int4",
"int8",
"uint8",
"int16",
Expand Down
56 changes: 55 additions & 1 deletion tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import tvm.testing
import tvm.topi.testing
from tvm import te, autotvm, topi
from tvm import te, autotvm, topi, relay
from tvm.contrib.pickle_memoize import memoize
from tvm.contrib import nvcc
from tvm.topi.nn.utils import get_pad_tuple
Expand Down Expand Up @@ -136,6 +136,59 @@ def check_target(target):
check_target("cuda")


def verify_feature_length():
np.random.seed(123)
target = "cuda"
ctx = tvm.device(target)

batch_size = 32

input_shape = (32, 512, 7, 7)
kernel_shape = (512, 512, 3, 3)

def get_mod():
x = relay.var("x", relay.TensorType(input_shape, "float32"))
y = relay.var("y", relay.TensorType(kernel_shape, "float32"))
f = relay.Function(
[x, y], relay.nn.conv2d(x, y, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3])
)
mod = tvm.IRModule()
mod["main"] = f
mod = relay.transform.InferType()(mod)
return mod, {}

mod, params = get_mod()
layout_config = relay.transform.LayoutConfig()
desired_layouts = {"nn.conv2d": ["HWNC", "default"]}
with layout_config:
seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
mod = relay.transform.recast(mod, "int4", "int32")

tasks = autotvm.task.extract_from_program(
mod, target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
)

assert len(tasks) == 1
task = tasks[0]

space = task.config_space

idx1 = np.random.randint(len(space))
idx2 = np.random.randint(len(space))

cfg = space.get(idx1)
sch, arg_bufs = task.instantiate(cfg)
fea1 = autotvm.feature.get_itervar_feature_flatten(sch, arg_bufs, take_log=True)

cfg = space.get(idx2)
sch, arg_bufs = task.instantiate(cfg)
fea2 = autotvm.feature.get_itervar_feature_flatten(sch, arg_bufs, take_log=True)

assert len(fea1) == len(fea2)


@tvm.testing.requires_tensorcore
def test_conv2d_hwnc_tensorcore():
"""Test the conv2d with tensorcore for hwnc layout"""
Expand All @@ -150,6 +203,7 @@ def test_conv2d_hwnc_tensorcore():
verify_conv2d_hwnc(8, 256, 14, 512, 3, 2, 1)
verify_conv2d_hwnc(8, 256, 14, 512, 1, 2, 0)
verify_conv2d_hwnc(8, 512, 9, 512, 3, 1, 1)
verify_feature_length()


if __name__ == "__main__":
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,27 @@ def check_cuda(n, value, lanes):
check_cuda(64, -3, 2)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_make_int4():
def check_cuda(n, value, lanes):
dtype = "int4"
dev = tvm.gpu(0)
A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
s = te.create_schedule(A.op)
y, x = s[A].op.axis
s[A].vectorize(x)
s[A].bind(y, bx)
fun = tvm.build(s, [A], "cuda", name="make_int4x8")
np_a = np.full((n, lanes), value, dtype="int8")
a = tvm.nd.empty((n, lanes), dtype, dev)
fun(a)
np.testing.assert_equal(a.asnumpy(), np_a)

check_cuda(64, 1, 8)
check_cuda(64, 7, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_inf_nan():
Expand Down Expand Up @@ -972,6 +993,7 @@ def test_unrolled_vectorization():
test_cuda_bf16_vectorize_add()
test_cuda_multiply_add()
test_cuda_vectorize_load()
test_cuda_make_int4()
test_cuda_make_int8()
test_cuda_inf_nan()
test_cuda_shuffle()
Expand Down

0 comments on commit dc1f189

Please sign in to comment.