Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[random] support random fill #5913

Merged
merged 3 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ set(USE_MKLDNN OFF)
set(USE_OPENMP none)

# Whether use contrib.random in runtime
set(USE_RANDOM OFF)
set(USE_RANDOM ON)
FrozenGene marked this conversation as resolved.
Show resolved Hide resolved

# Whether use NNPack
set(USE_NNPACK OFF)
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,14 @@ inline size_t GetDataSize(const DLTensor& arr) {
return size;
}

/*!
* \brief return the data alignment in the DLTensor
*
* \param arr the input DLTensor
* \return alignment of data in the DLTensor.
*/
size_t GetDataAlignment(const DLTensor& arr);

/*!
* \brief check if a DLTensor is contiguous.
* \param arr The input DLTensor.
Expand Down
54 changes: 54 additions & 0 deletions src/runtime/contrib/random/mt_random_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
* \brief mt19937 random engine
*/
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>

#include <algorithm>
#include <ctime>
#include <random>

#include "../3rdparty/compiler-rt/builtin_fp16.h"

namespace tvm {
namespace contrib {

Expand Down Expand Up @@ -111,6 +115,56 @@ class RandomEngine {
}
}

void RandomFill(DLTensor* data) {
int64_t size = 1;
for (int i = 0; i < data->ndim; ++i) {
size *= data->shape[i];
}

if (data->ctx.device_type == kDLCPU) {
FillData(data, size);
} else {
DLTensor local;
FrozenGene marked this conversation as resolved.
Show resolved Hide resolved
local.shape = data->shape;
local.ndim = data->ndim;
local.dtype = data->dtype;
local.strides = data->strides;
local.byte_offset = data->byte_offset;
local.ctx = {kDLCPU, 0};
local.data = runtime::DeviceAPI::Get(local.ctx)->AllocDataSpace(
{kDLCPU, 0}, runtime::GetDataSize(local), runtime::GetDataAlignment(local), local.dtype);
FillData(&local, size);
runtime::NDArray::CopyFromTo(&local, data);
}
}

private:
void FillData(DLTensor* tensor, int64_t size) {
// Make the value be 1.0 - 10.0, not (0.0 - 1.0) so that we could satisfy
// quantized dtype (uint8 / int8) data non-empty requirement
std::uniform_real_distribution<> dist(1.0, 10.0);
// Use float representation could make us work well on float / int type too.
if (tensor->dtype.bits == 1) {
std::generate_n(static_cast<bool*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
} else if (tensor->dtype.bits == 8) {
std::generate_n(static_cast<uint8_t*>(tensor->data), size,
[&]() { return dist(rnd_engine_); });
} else if (tensor->dtype.bits == 16) {
std::generate_n(static_cast<uint16_t*>(tensor->data), size, [&]() {
return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(dist(rnd_engine_)));
});
} else if (tensor->dtype.bits == 32) {
std::generate_n(static_cast<float*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
} else if (tensor->dtype.bits == 64) {
std::generate_n(static_cast<double*>(tensor->data), size,
[&]() { return dist(rnd_engine_); });
} else {
LOG(FATAL) << "Doesn't support dtype code " << tensor->dtype.code << " dtype bits "
<< tensor->dtype.bits;
}
}

private:
std::mt19937 rnd_engine_;
unsigned rseed_;
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/contrib/random/random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRe
entry->random_engine.SampleNormal(out, loc, scale);
});

TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args, TVMRetValue* ret) {
RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
DLTensor* out = args[0];
entry->random_engine.RandomFill(out);
});

} // namespace contrib
} // namespace tvm
2 changes: 1 addition & 1 deletion src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ inline void VerifyDataType(DLDataType dtype) {
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}

inline size_t GetDataAlignment(const DLTensor& arr) {
size_t GetDataAlignment(const DLTensor& arr) {
size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
if (align < kAllocAlignment) return kAllocAlignment;
return align;
Expand Down
57 changes: 57 additions & 0 deletions tests/python/contrib/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@
from tvm import te
import numpy as np
from tvm.contrib import random
from tvm import rpc

def enabled_ctx_list():
ctx_list = [('cpu', tvm.cpu(0)),
('gpu', tvm.gpu(0)),
('cl', tvm.opencl(0)),
('metal', tvm.metal(0)),
('rocm', tvm.rocm(0)),
('vulkan', tvm.vulkan(0)),
('vpi', tvm.vpi(0))]
for k, v in ctx_list:
assert tvm.context(k, 0) == v
ctx_list = [x[1] for x in ctx_list if x[1].exist]
return ctx_list

ENABLED_CTX_LIST = enabled_ctx_list()

def test_randint():
m = 1024
Expand Down Expand Up @@ -89,8 +105,49 @@ def verify(target="llvm"):
assert abs(np.std(na) - 4) < 1e-2
verify()

def test_random_fill():
def test_local(ctx, dtype):
if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
print("skip because extern function is not available")
return
np_ones = np.ones((512, 512), dtype=dtype)
value = tvm.nd.empty(np_ones.shape, np_ones.dtype, ctx)
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
random_fill(value)

assert np.count_nonzero(value.asnumpy()) == 512 * 512

# make sure arithmentic doesn't overflow too
np_values = value.asnumpy()
assert np.isfinite(np_values * np_values + np_values).any()

def test_rpc(dtype):
if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
print("skip because extern function is not available")
return
if not tvm.runtime.enabled("rpc") or not tvm.runtime.enabled("llvm"):
return
np_ones = np.ones((512, 512), dtype=dtype)
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
value = tvm.nd.empty(np_ones.shape, np_ones.dtype, remote.cpu())
random_fill = remote.get_function("tvm.contrib.random.random_fill")
random_fill(value)

assert np.count_nonzero(value.asnumpy()) == 512 * 512

# make sure arithmentic doesn't overflow too
np_values = value.asnumpy()
assert np.isfinite(np_values * np_values + np_values).any()

for dtype in ["bool", "int8", "uint8", "int16", "uint16", "int32", "int32",
"int64", "uint64", "float16", "float32", "float64"]:
for ctx in ENABLED_CTX_LIST:
test_local(ctx, dtype)
test_rpc(dtype)

if __name__ == "__main__":
test_randint()
test_uniform()
test_normal()
test_random_fill()