[random] support random fill (apache#5913)
FrozenGene authored and Trevor Morris committed Sep 2, 2020
1 parent 4257415 commit ae9be5f
Showing 4 changed files with 111 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cmake/config.cmake
@@ -152,7 +152,7 @@ set(USE_MKLDNN OFF)
set(USE_OPENMP none)

# Whether to use contrib.random in runtime
-set(USE_RANDOM OFF)
+set(USE_RANDOM ON)

# Whether to use NNPack
set(USE_NNPACK OFF)
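Once USE_RANDOM is ON, the contrib random engine is compiled into the TVM runtime. A minimal sketch for verifying a build from Python — it reuses the same lookup guard as the tests below, and the error message is illustrative, not part of the commit:

import tvm

# "tvm.contrib.random.random_fill" is only registered when the runtime was
# built with USE_RANDOM ON; passing True returns None instead of raising.
if tvm.get_global_func("tvm.contrib.random.random_fill", True) is None:
    raise RuntimeError("rebuild with set(USE_RANDOM ON) in cmake/config.cmake")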
47 changes: 47 additions & 0 deletions src/runtime/contrib/random/mt_random_engine.cc
@@ -22,11 +22,15 @@
* \brief mt19937 random engine
*/
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>

#include <algorithm>
#include <ctime>
#include <random>

#include "../3rdparty/compiler-rt/builtin_fp16.h"

namespace tvm {
namespace contrib {

@@ -111,6 +115,49 @@ class RandomEngine {
    }
  }

  void RandomFill(DLTensor* data) {
    int64_t size = 1;
    for (int i = 0; i < data->ndim; ++i) {
      size *= data->shape[i];
    }

    if (data->ctx.device_type == kDLCPU) {
      FillData(data, size);
    } else {
      runtime::NDArray local = runtime::NDArray::Empty(
          std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
      FillData(&local.ToDLPack()->dl_tensor, size);
      runtime::NDArray::CopyFromTo(&local.ToDLPack()->dl_tensor, data);
    }
  }

 private:
  void FillData(DLTensor* tensor, int64_t size) {
    // Draw from [1.0, 10.0) rather than [0.0, 1.0) so that quantized dtypes
    // (uint8 / int8) are filled with non-zero data.
    std::uniform_real_distribution<> dist(1.0, 10.0);
    // Generating float values and casting works for float and integer dtypes alike.
    if (tensor->dtype.bits == 1) {
      std::generate_n(static_cast<bool*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
    } else if (tensor->dtype.bits == 8) {
      std::generate_n(static_cast<uint8_t*>(tensor->data), size,
                      [&]() { return dist(rnd_engine_); });
    } else if (tensor->dtype.bits == 16) {
      std::generate_n(static_cast<uint16_t*>(tensor->data), size, [&]() {
        return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
            static_cast<float>(dist(rnd_engine_)));
      });
    } else if (tensor->dtype.bits == 32) {
      std::generate_n(static_cast<float*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
    } else if (tensor->dtype.bits == 64) {
      std::generate_n(static_cast<double*>(tensor->data), size,
                      [&]() { return dist(rnd_engine_); });
    } else {
      LOG(FATAL) << "Unsupported dtype code " << tensor->dtype.code << " with dtype bits "
                 << tensor->dtype.bits;
    }
  }

 private:
  std::mt19937 rnd_engine_;
  unsigned rseed_;
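RandomFill fills CPU tensors in place; for any other device it stages the values in a host-side NDArray and copies them over with CopyFromTo. FillData draws from a uniform [1.0, 10.0) distribution and narrows to the tensor's bit width, going through __truncXfYf2__ in the 16-bit case. A rough numpy sketch of the resulting values — illustrative only, since numpy's cast rounds where the builtin truncates, and the C++ 16-bit branch writes half-precision bit patterns even for integer 16-bit dtypes:

import numpy as np

def sketch_fill(shape, dtype):
    # Values in [1.0, 10.0), as in FillData, so integer dtypes stay non-zero.
    vals = np.random.uniform(1.0, 10.0, size=shape)
    if dtype == "float16":
        # Approximates the float32 -> half conversion done by __truncXfYf2__.
        return vals.astype("float32").astype("float16")
    return vals.astype(dtype)

print(sketch_fill((2, 2), "uint8"))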
6 changes: 6 additions & 0 deletions src/runtime/contrib/random/random.cc
@@ -117,5 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRe
  entry->random_engine.SampleNormal(out, loc, scale);
});

TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args, TVMRetValue* ret) {
  RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
  DLTensor* out = args[0];
  entry->random_engine.RandomFill(out);
});

} // namespace contrib
} // namespace tvm
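The registration above exposes RandomFill to every TVM frontend as the packed function "tvm.contrib.random.random_fill", which writes into the tensor it is handed. A minimal local-usage sketch, following the same pattern as the new test below:

import tvm

random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
value = tvm.nd.empty((512, 512), "float32", tvm.cpu(0))
random_fill(value)  # fills `value` in place
assert value.asnumpy().min() >= 1.0  # values come from [1.0, 10.0)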
57 changes: 57 additions & 0 deletions tests/python/contrib/test_random.py
@@ -18,6 +18,22 @@
from tvm import te
import numpy as np
from tvm.contrib import random
from tvm import rpc

def enabled_ctx_list():
    ctx_list = [('cpu', tvm.cpu(0)),
                ('gpu', tvm.gpu(0)),
                ('cl', tvm.opencl(0)),
                ('metal', tvm.metal(0)),
                ('rocm', tvm.rocm(0)),
                ('vulkan', tvm.vulkan(0)),
                ('vpi', tvm.vpi(0))]
    for k, v in ctx_list:
        assert tvm.context(k, 0) == v
    ctx_list = [x[1] for x in ctx_list if x[1].exist]
    return ctx_list

ENABLED_CTX_LIST = enabled_ctx_list()

def test_randint():
    m = 1024
@@ -89,8 +105,49 @@ def verify(target="llvm"):
        assert abs(np.std(na) - 4) < 1e-2
    verify()

def test_random_fill():
    def test_local(ctx, dtype):
        if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
            print("skip because extern function is not available")
            return
        np_ones = np.ones((512, 512), dtype=dtype)
        value = tvm.nd.empty(np_ones.shape, np_ones.dtype, ctx)
        random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
        random_fill(value)

        assert np.count_nonzero(value.asnumpy()) == 512 * 512

        # make sure arithmetic doesn't overflow too
        np_values = value.asnumpy()
        assert np.isfinite(np_values * np_values + np_values).any()

    def test_rpc(dtype):
        if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
            print("skip because extern function is not available")
            return
        if not tvm.runtime.enabled("rpc") or not tvm.runtime.enabled("llvm"):
            return
        np_ones = np.ones((512, 512), dtype=dtype)
        server = rpc.Server("localhost")
        remote = rpc.connect(server.host, server.port)
        value = tvm.nd.empty(np_ones.shape, np_ones.dtype, remote.cpu())
        random_fill = remote.get_function("tvm.contrib.random.random_fill")
        random_fill(value)

        assert np.count_nonzero(value.asnumpy()) == 512 * 512

        # make sure arithmetic doesn't overflow too
        np_values = value.asnumpy()
        assert np.isfinite(np_values * np_values + np_values).any()

    for dtype in ["bool", "int8", "uint8", "int16", "uint16", "int32", "uint32",
                  "int64", "uint64", "float16", "float32", "float64"]:
        for ctx in ENABLED_CTX_LIST:
            test_local(ctx, dtype)
        test_rpc(dtype)

if __name__ == "__main__":
    test_randint()
    test_uniform()
    test_normal()
    test_random_fill()
