diff --git a/HalideIR b/HalideIR index 9204453ae8de7..a5a80bdc8232c 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 9204453ae8de77e7dfc32c4d80f58dd788ad75ff +Subproject commit a5a80bdc8232c9dbfe508bb5c46e8f58cdf7ec20 diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index dfb06255381aa..30338cdaaeb88 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -10,6 +10,7 @@ #include #include #include "./c_runtime_api.h" +#include "./serializer.h" namespace tvm { namespace runtime { @@ -105,6 +106,17 @@ class NDArray { */ inline void CopyTo(DLTensor* other); inline void CopyTo(const NDArray& other); + /*! + * \brief Load NDArray from stream + * \param stream The input data stream + * \return Whether load is successful + */ + inline bool Load(dmlc::Stream* stream); + /*! + * \brief Save NDArray to stream + * \param stream The output data stream + */ + inline void Save(dmlc::Stream* stream) const; /*! * \brief Create a NDArray that shares the data memory with the current one. * \param shape The shape of the new array. @@ -161,6 +173,13 @@ class NDArray { friend class TVMArgsSetter; }; +/*! + * \brief Save a DLTensor to stream + * \param strm The outpu stream + * \param tensor The tensor to be saved. + */ +inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); + /*! * \brief Reference counted Container object used to back NDArray. * @@ -280,7 +299,99 @@ inline const DLTensor* NDArray::operator->() const { return &(data_->dl_tensor); } +/*! \brief Magic number for NDArray file */ +constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; + +inline bool SaveDLTensor(dmlc::Stream* strm, + DLTensor* tensor) { + uint64_t header = kTVMNDArrayMagic, reserved = 0; + strm->Write(header); + strm->Write(reserved); + // always save data as CPU context + // so that when loading, it will be loaded as CPU ctx. + DLContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + strm->Write(cpu_ctx); + strm->Write(tensor->ndim); + strm->Write(tensor->dtype); + int ndim = tensor->ndim; + strm->WriteArray(tensor->shape, ndim); + int type_bytes = tensor->dtype.bits / 8; + int64_t num_elems = 1; + for (int i = 0; i < ndim; ++i) { + num_elems *= tensor->shape[i]; + } + int64_t data_byte_size = type_bytes * num_elems; + strm->Write(data_byte_size); + + if (DMLC_IO_NO_ENDIAN_SWAP && + tensor->ctx.device_type == kDLCPU && + tensor->strides == nullptr && + tensor->byte_offset == 0) { + // quick path + strm->Write(tensor->data, data_byte_size); + } else { + std::vector bytes(data_byte_size); + CHECK_EQ(TVMArrayCopyToBytes( + tensor, dmlc::BeginPtr(bytes), data_byte_size), 0) + << TVMGetLastError(); + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); + } + strm->Write(dmlc::BeginPtr(bytes), data_byte_size); + } + return true; +} + +inline void NDArray::Save(dmlc::Stream* strm) const { + SaveDLTensor(strm, const_cast(operator->())); +} + +inline bool NDArray::Load(dmlc::Stream* strm) { + uint64_t header, reserved; + CHECK(strm->Read(&header)) + << "Invalid DLTensor file format"; + CHECK(strm->Read(&reserved)) + << "Invalid DLTensor file format"; + CHECK(header == kTVMNDArrayMagic) + << "Invalid DLTensor file format"; + DLContext ctx; + int ndim; + DLDataType dtype; + CHECK(strm->Read(&ctx)) + << "Invalid DLTensor file format"; + CHECK(strm->Read(&ndim)) + << "Invalid DLTensor file format"; + CHECK(strm->Read(&dtype)) + << "Invalid DLTensor file format"; + CHECK_EQ(ctx.device_type, kDLCPU) + << "Invalid DLTensor context: can only save as CPU tensor"; + std::vector shape(ndim); + if (ndim != 0) { + CHECK(strm->ReadArray(&shape[0], ndim)) + << "Invalid DLTensor file format"; + } + NDArray ret = NDArray::Empty(shape, dtype, ctx); + int64_t num_elems = 1; + int elem_bytes = (ret->dtype.bits + 7) / 8; + for (int i = 0; i < ret->ndim; ++i) { + num_elems *= ret->shape[i]; + } + int64_t data_byte_size; + CHECK(strm->Read(&data_byte_size)) + << "Invalid DLTensor file format"; + CHECK(data_byte_size == num_elems * elem_bytes) + << "Invalid DLTensor file format"; + CHECK(strm->Read(ret->data, data_byte_size)) + << "Invalid DLTensor file format"; + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(ret->data, elem_bytes, num_elems); + } + *this = ret; + return true; +} + } // namespace runtime } // namespace tvm - #endif // TVM_RUNTIME_NDARRAY_H_ diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index 391c7806ad9ca..b2ab5483a22df 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -10,6 +10,7 @@ #include #include #include "./c_runtime_api.h" +#include "./ndarray.h" namespace dmlc { namespace serializer { diff --git a/nnvm/python/nnvm/compiler/param_dict.py b/nnvm/python/nnvm/compiler/param_dict.py index e5db3ce9fc816..78c0e5338e550 100644 --- a/nnvm/python/nnvm/compiler/param_dict.py +++ b/nnvm/python/nnvm/compiler/param_dict.py @@ -59,11 +59,5 @@ def load_param_dict(param_bytes): """ if isinstance(param_bytes, (bytes, str)): param_bytes = bytearray(param_bytes) - load_mod = _load_param_dict(param_bytes) - size = load_mod(0) - param_dict = {} - for i in range(size): - key = load_mod(1, i) - dltensor_handle = ctypes.cast(load_mod(2, i), TVMArrayHandle) - param_dict[key] = tvm.nd.NDArray(dltensor_handle, False) - return param_dict + load_arr = _load_param_dict(param_bytes) + return {v.name : v.array for v in load_arr} diff --git a/nnvm/src/compiler/graph_runtime.cc b/nnvm/src/compiler/graph_runtime.cc index 689ed70ce9f20..e623192258de3 100644 --- a/nnvm/src/compiler/graph_runtime.cc +++ b/nnvm/src/compiler/graph_runtime.cc @@ -4,10 +4,6 @@ * \brief Interface code with TVM graph runtime. */ #include -#include -#include -#include -#include #include "./graph_runtime.h" namespace nnvm { @@ -37,81 +33,6 @@ NNVM_REGISTER_OP(tvm_op) return param.num_outputs; }); -bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) { - uint64_t header = kTVMNDArrayMagic, reserved = 0; - strm->Write(header); - strm->Write(reserved); - strm->Write(tensor->ctx); - strm->Write(tensor->ndim); - strm->Write(tensor->dtype); - int ndim = tensor->ndim; - strm->WriteArray(tensor->shape, ndim); - - int type_bytes = tensor->dtype.bits / 8; - int64_t num_elems = 1; - for (int i = 0; i < ndim; ++i) { - num_elems *= tensor->shape[i]; - } - int64_t data_byte_size = type_bytes * num_elems; - strm->Write(data_byte_size); - // handle endianness of data correctly. - if (DMLC_IO_NO_ENDIAN_SWAP) { - strm->Write(tensor->data, data_byte_size); - } else { - uint8_t* dptr = reinterpret_cast(tensor->data); - std::vector bytes(dptr, dptr + data_byte_size); - dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); - strm->Write(dmlc::BeginPtr(bytes), data_byte_size); - } - return true; -} - -DLTensor* LoadDLTensor(dmlc::Stream* strm) { - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&reserved)) - << "Invalid DLTensor file format"; - CHECK(header == kTVMNDArrayMagic) - << "Invalid DLTensor file format"; - DLTensor tensor; - CHECK(strm->Read(&(tensor.ctx))) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&(tensor.ndim))) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&(tensor.dtype))) - << "Invalid DLTensor file format"; - std::vector shape(tensor.ndim); - if (tensor.ndim != 0) { - CHECK(strm->ReadArray(&shape[0], tensor.ndim)) - << "Invalid DLTensor file format"; - } - DLTensor* ret; - CHECK_EQ(TVMArrayAlloc(shape.data(), - tensor.ndim, - tensor.dtype.code, - tensor.dtype.bits, - tensor.dtype.lanes, - static_cast(tensor.ctx.device_type), - tensor.ctx.device_id, - &ret), 0) << TVMGetLastError(); - int64_t num_elems = 1; - int elem_bytes = (ret->dtype.bits + 7) / 8; - for (int i = 0; i < ret->ndim; ++i) { - num_elems *= ret->shape[i]; - } - int64_t data_byte_size; - CHECK(strm->Read(&data_byte_size)) - << "Invalid DLTensor file format"; - CHECK(data_byte_size == num_elems * elem_bytes) - << "Invalid DLTensor file format"; - CHECK(strm->Read(ret->data, data_byte_size)) - << "Invalid DLTensor file format"; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(ret->data, elem_bytes, num_elems); - } - return ret; -} TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -136,7 +57,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") uint64_t sz = static_cast(arrays.size()); fo->Write(sz); for (size_t i = 0; i < sz; ++i) { - SaveDLTensor(fo, arrays[i]); + tvm::runtime::SaveDLTensor(fo, arrays[i]); } } TVMByteArray arr; @@ -149,11 +70,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") .set_body([](TVMArgs args, TVMRetValue *rv) { std::string bytes = args[0]; - std::vector data; std::vector names; dmlc::MemoryStringStream memstrm(&bytes); dmlc::Stream* strm = &memstrm; - uint64_t header, reserved; CHECK(strm->Read(&header)) << "Invalid parameters file format"; @@ -168,23 +87,19 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") size_t size = static_cast(sz); CHECK(size == names.size()) << "Invalid parameters file format"; + tvm::Array ret; for (size_t i = 0; i < size; ++i) { - data.push_back(LoadDLTensor(strm)); + tvm::runtime::NDArray temp; + temp.Load(strm); + std::shared_ptr n + = std::make_shared(); + n->name = std::move(names[i]); + n->array = temp; + ret.push_back(NDArrayWrapper(n)); } - auto packed = [data, names](TVMArgs args, TVMRetValue* rv) { - int code = args[0]; - if (code == 0) { - *rv = static_cast(data.size()); - } else if (code == 1) { - int index = args[1]; - *rv = names[index]; - } else { - CHECK_EQ(code, 2); - int index = args[1]; - *rv = static_cast(data[index]); - } - }; - *rv = PackedFunc(packed); + *rv = ret; }); + +TVM_REGISTER_NODE_TYPE(NDArrayWrapperNode); } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h index 49a2e9bfc9751..958e82aa8015a 100644 --- a/nnvm/src/compiler/graph_runtime.h +++ b/nnvm/src/compiler/graph_runtime.h @@ -7,14 +7,16 @@ #define NNVM_COMPILER_GRAPH_RUNTIME_H_ #include +#include +#include +#include +#include #include #include namespace nnvm { namespace compiler { -/*! \brief Magic number for NDArray file */ -constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; /*! \brief Magic number for NDArray list file */ constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; @@ -32,6 +34,27 @@ struct TVMOpParam : public dmlc::Parameter { } }; + +/*! + * \brief wrapper node container for exchange. + */ +struct NDArrayWrapperNode : public ::tvm::Node { + std::string name; + tvm::runtime::NDArray array; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("array", &array); + } + + static constexpr const char* _type_key = "NDArrayWrapper"; + TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node); +}; + +/*! \brief Defines memory info */ +TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode); + } // namespace compiler } // namespace nnvm + #endif // NNVM_COMPILER_GRAPH_RUNTIME_H_ diff --git a/nnvm/tests/python/compiler/test_param_dict.py b/nnvm/tests/python/compiler/test_param_dict.py index 4fa41a3a58337..a6605123fa0d2 100644 --- a/nnvm/tests/python/compiler/test_param_dict.py +++ b/nnvm/tests/python/compiler/test_param_dict.py @@ -2,6 +2,9 @@ import numpy as np import nnvm.compiler import tvm +import json +import base64 +from tvm._ffi.base import py_str from tvm import rpc from tvm.contrib import util, graph_runtime @@ -20,6 +23,22 @@ def test_save_load(): np.testing.assert_equal(param2["y"].asnumpy(), y) +def test_ndarray_reflection(): + x = np.random.uniform(size=(10, 2)).astype("float32") + xx = tvm.nd.array(x) + xnode = tvm.make.node("NDArrayWrapper", name="xx", array=xx) + xnode2 = tvm.make.node("NDArrayWrapper", name="x2", array=xx) + assert xnode.array.same_as(xx) + json_str = tvm.save_json([xnode, xnode2]) + json_dict = json.loads(json_str) + b64_str = json_dict["b64ndarrays"][0] + decoded = py_str(base64.b64encode(base64.b64decode(b64_str))) + assert b64_str == decoded + xlist = tvm.load_json(json_str) + np.testing.assert_equal(xlist[0].array.asnumpy(), xx.asnumpy()) + assert xlist[1].array == xlist[0].array + + def test_bigendian_rpc_param(): """Test big endian rpc when there is a PowerPC RPC server available""" host = os.environ.get("TVM_POWERPC_TEST_HOST", None) @@ -60,5 +79,6 @@ def verify_nnvm(remote, target, shape, dtype): if __name__ == "__main__": + test_ndarray_reflection() test_save_load() test_bigendian_rpc_param() diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 602af3ef858b1..79f3c6033a1fa 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -204,6 +204,7 @@ def _handle_return_func(x): # setup return handle for function type RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module +RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( _handle_return_func, TypeCode.FUNC_HANDLE) C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/node.pxi index 93b3fb124f0a1..a563af5237f90 100644 --- a/python/tvm/_ffi/_cython/node.pxi +++ b/python/tvm/_ffi/_cython/node.pxi @@ -23,6 +23,8 @@ cdef inline object make_ret_node(void* chandle): obj = cls(None) else: obj = NodeBase(None) + else: + obj = NodeBase(None) (obj).chandle = chandle return obj diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 135701a803c05..e22950f6eabfd 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -134,6 +134,21 @@ def context(self): """context of this array""" return self.ctx + def __hash__(self): + return ctypes.cast(self.handle, ctypes.c_void_p).value + + def __eq__(self, other): + return self.same_as(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """check object identity equality""" + if not isinstance(other, NDArrayBase): + return False + return self.__hash__() == other.__hash__() + def __setitem__(self, in_slice, value): """Set ndarray value""" if (not isinstance(in_slice, slice) or diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 4e247ed2bf4c2..80d7c3163e10c 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -32,7 +32,7 @@ using TVMAPINode = std::shared_ptr; struct APIAttrGetter : public AttrVisitor { std::string skey; TVMRetValue* ret; - bool found_node_ref{false}; + bool found_ref_object{false}; void Visit(const char* key, double* value) final { if (skey == key) *ret = value[0]; @@ -63,7 +63,13 @@ struct APIAttrGetter : public AttrVisitor { void Visit(const char* key, NodeRef* value) final { if (skey == key) { *ret = value[0]; - found_node_ref = true; + found_ref_object = true; + } + } + void Visit(const char* key, runtime::NDArray* value) final { + if (skey == key) { + *ret = value[0]; + found_ref_object = true; } } }; @@ -98,6 +104,9 @@ struct APIAttrDir : public AttrVisitor { void Visit(const char* key, NodeRef* value) final { names->push_back(key); } + void Visit(const char* key, runtime::NDArray* value) final { + names->push_back(key); + } }; class DSLAPIImpl : public DSLAPI { @@ -130,7 +139,7 @@ class DSLAPIImpl : public DSLAPI { *ret_success = 1; } else { (*tnode)->VisitAttrs(&getter); - *ret_success = getter.found_node_ref || rv.type_code() != kNull; + *ret_success = getter.found_ref_object || rv.type_code() != kNull; if (rv.type_code() == kStr || rv.type_code() == kTVMType) { TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get(); diff --git a/src/common/base64.h b/src/common/base64.h new file mode 100644 index 0000000000000..31b02d3ca2a34 --- /dev/null +++ b/src/common/base64.h @@ -0,0 +1,284 @@ +/*! + * Copyright 2018 by Contributors + * + * \file base64.h + * \brief data stream support to input and output from/to base64 stream + * base64 is easier to store and pass as text format in mapreduce + */ +#ifndef TVM_COMMON_BASE64_H_ +#define TVM_COMMON_BASE64_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace common { +/*! \brief namespace of base64 decoding and encoding table */ +namespace base64 { +// decoding table +const char DecodeTable[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' +}; +// encoding table +static const char EncodeTable[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +} // namespace base64 + +/*! + * \brief Buffer reader from stream to avoid + * virtual call overhead on each read. + */ +class StreamBufferReader { + public: + explicit StreamBufferReader(size_t buffer_size) { + buffer_.resize(buffer_size); + } + /*! + * \brief set input stream + * \param stream The stream to be set + */ + void set_stream(dmlc::Stream *stream) { + stream_ = stream; + read_len_ = read_ptr_ = 1; + } + /*! + * \return allows quick read using get char + */ + char GetChar() { + while (true) { + if (read_ptr_ < read_len_) { + return buffer_[read_ptr_++]; + } else { + read_len_ = stream_->Read(&buffer_[0], buffer_.length()); + if (read_len_ == 0) return EOF; + read_ptr_ = 0; + } + } + } + /*! \return whether we are reaching the end of file */ + bool AtEnd() const { + return read_len_ == 0; + } + + private: + /*! \brief the underlying stream */ + dmlc::Stream *stream_{nullptr}; + /*! \brief buffer to hold data */ + std::string buffer_; + /*! \brief length of valid data in buffer */ + size_t read_len_{1}; + /*! \brief pointer in the buffer */ + size_t read_ptr_{1}; +}; + +/*! + * \brief Input stream from base64 encoding + */ +class Base64InStream: public dmlc::Stream { + public: + explicit Base64InStream(dmlc::Stream *fs) : reader_(256) { + reader_.set_stream(fs); + } + /*! + * \brief initialize the stream position to beginning of next base64 stream + * \note call this function before actually start read + */ + void InitPosition(void) { + // get a character + do { + temp_ch_ = reader_.GetChar(); + } while (isspace(temp_ch_)); + } + /*! \brief whether current position is end of a base64 stream */ + bool IsEOF(void) const { + return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); + } + // override read function. + virtual size_t Read(void *ptr, size_t size) { + using base64::DecodeTable; + if (size == 0) return 0; + // use tlen to record left size + size_t tlen = size; + unsigned char *cptr = static_cast(ptr); + // if anything left, load from previous buffered result + if (num_prev_ != 0) { + if (num_prev_ == 2) { + if (tlen >= 2) { + *cptr++ = buf_prev[0]; + *cptr++ = buf_prev[1]; + tlen -= 2; + num_prev_ = 0; + } else { + // assert tlen == 1 + *cptr++ = buf_prev[0]; --tlen; + buf_prev[0] = buf_prev[1]; + num_prev_ = 1; + } + } else { + // assert num_prev_ == 1 + *cptr++ = buf_prev[0]; --tlen; num_prev_ = 0; + } + } + if (tlen == 0) return size; + int nvalue; + // note: everything goes with 4 bytes in Base64 + // so we process 4 bytes a unit + while (tlen && temp_ch_ != EOF && !isspace(temp_ch_)) { + // first byte + nvalue = DecodeTable[temp_ch_] << 18; + { + // second byte + temp_ch_ = reader_.GetChar(); + CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; + nvalue |= DecodeTable[temp_ch_] << 12; + *cptr++ = (nvalue >> 16) & 0xFF; --tlen; + } + { + // third byte + temp_ch_ = reader_.GetChar(); + CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; + // handle termination + if (temp_ch_ == '=') { + temp_ch_ = reader_.GetChar(); + CHECK(temp_ch_ == '=') << "invalid base64 format"; + temp_ch_ = reader_.GetChar(); + CHECK(temp_ch_ == EOF || isspace(temp_ch_)) + << "invalid base64 format"; + break; + } + nvalue |= DecodeTable[temp_ch_] << 6; + if (tlen) { + *cptr++ = (nvalue >> 8) & 0xFF; --tlen; + } else { + buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF; + } + } + { + // fourth byte + temp_ch_ = reader_.GetChar(); + CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) + << "invalid base64 format"; + if (temp_ch_ == '=') { + temp_ch_ = reader_.GetChar(); + CHECK(temp_ch_ == EOF || isspace(temp_ch_)) + << "invalid base64 format"; + break; + } + nvalue |= DecodeTable[temp_ch_]; + if (tlen) { + *cptr++ = nvalue & 0xFF; --tlen; + } else { + buf_prev[num_prev_ ++] = nvalue & 0xFF; + } + } + // get next char + temp_ch_ = reader_.GetChar(); + } + if (kStrictCheck) { + CHECK_EQ(tlen, 0) << "Base64InStream: read incomplete"; + } + return size - tlen; + } + virtual void Write(const void *ptr, size_t size) { + LOG(FATAL) << "Base64InStream do not support write"; + } + + private: + // internal reader + StreamBufferReader reader_; + int temp_ch_{0}; + int num_prev_{0}; + unsigned char buf_prev[2]; + // whether we need to do strict check + static const bool kStrictCheck = false; +}; + +/*! + * \brief Stream to write to base64 format. + */ +class Base64OutStream: public dmlc::Stream { + public: + explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) { + } + virtual void Write(const void *ptr, size_t size) { + using base64::EncodeTable; + size_t tlen = size; + const unsigned char *cptr = static_cast(ptr); + while (tlen) { + while (buf__top_ < 3 && tlen != 0) { + buf_[++buf__top_] = *cptr++; --tlen; + } + if (buf__top_ == 3) { + // flush 4 bytes out + PutChar(EncodeTable[buf_[1] >> 2]); + PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]); + PutChar(EncodeTable[((buf_[2] << 2) | (buf_[3] >> 6)) & 0x3F]); + PutChar(EncodeTable[buf_[3] & 0x3F]); + buf__top_ = 0; + } + } + } + virtual size_t Read(void *ptr, size_t size) { + LOG(FATAL) << "Base64OutStream do not support read"; + return 0; + } + /*! + * \brief finish writing of all current base64 stream, do some post processing + * \param endch character to put to end of stream, if it is EOF, then nothing will be appended. + */ + void Finish(char endch = EOF) { + using base64::EncodeTable; + if (buf__top_ == 1) { + PutChar(EncodeTable[buf_[1] >> 2]); + PutChar(EncodeTable[(buf_[1] << 4) & 0x3F]); + PutChar('='); + PutChar('='); + } + if (buf__top_ == 2) { + PutChar(EncodeTable[buf_[1] >> 2]); + PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]); + PutChar(EncodeTable[(buf_[2] << 2) & 0x3F]); + PutChar('='); + } + buf__top_ = 0; + if (endch != EOF) PutChar(endch); + this->Flush(); + } + + private: + static constexpr size_t kBufferSize = 256; + + dmlc::Stream *fp_{nullptr}; + int buf__top_{0}; + unsigned char buf_[4]; + std::string out_buf_; + + + void PutChar(char ch) { + out_buf_ += ch; + if (out_buf_.length() >= kBufferSize) Flush(); + } + void Flush(void) { + if (out_buf_.length() != 0) { + fp_->Write(&out_buf_[0], out_buf_.length()); + out_buf_.clear(); + } + } +}; +} // namespace common +} // namespace tvm +#endif // TVM_COMMON_BASE64_H_ diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index d0d8570067ebe..7c4e862f0abbd 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -7,8 +7,11 @@ #include #include #include +#include #include +#include #include +#include "../common/base64.h" namespace dmlc { DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); @@ -23,6 +26,7 @@ inline std::string Type2String(const Type& t) { return os.str(); } + inline Type String2Type(std::string s) { std::istringstream is(s); halideir_type_code_t code = Type::Int; @@ -52,6 +56,8 @@ class NodeIndexer : public AttrVisitor { public: std::unordered_map node_index{{nullptr, 0}}; std::vector node_list{nullptr}; + std::unordered_map tensor_index; + std::vector tensor_list; void Visit(const char* key, double* value) final {} void Visit(const char* key, int64_t* value) final {} @@ -64,7 +70,13 @@ class NodeIndexer : public AttrVisitor { void Visit(const char* key, NodeRef* value) final { MakeIndex(value->node_.get()); } - + void Visit(const char* key, runtime::NDArray* value) final { + DLTensor* ptr = const_cast((*value).operator->()); + if (tensor_index.count(ptr)) return; + CHECK_EQ(tensor_index.size(), tensor_list.size()); + tensor_index[ptr] = tensor_list.size(); + tensor_list.push_back(ptr); + } // make index of all the children of node void MakeIndex(Node* node) { if (node == nullptr) return; @@ -140,6 +152,7 @@ struct JSONNode { class JSONAttrGetter : public AttrVisitor { public: const std::unordered_map* node_index_; + const std::unordered_map* tensor_index_; JSONNode* node_; void Visit(const char* key, double* value) final { @@ -170,6 +183,10 @@ class JSONAttrGetter : public AttrVisitor { node_->attrs[key] = std::to_string( node_index_->at(value->node_.get())); } + void Visit(const char* key, runtime::NDArray* value) final { + node_->attrs[key] = std::to_string( + tensor_index_->at(const_cast((*value).operator->()))); + } // Get the node void Get(Node* node) { if (node == nullptr) { @@ -209,6 +226,7 @@ class JSONAttrGetter : public AttrVisitor { class JSONAttrSetter : public AttrVisitor { public: const std::vector >* node_list_; + const std::vector* tensor_list_; JSONNode* node_; std::string GetValue(const char* key) const { @@ -254,10 +272,16 @@ class JSONAttrSetter : public AttrVisitor { void Visit(const char* key, NodeRef* value) final { size_t index; ParseValue(key, &index); + CHECK_LE(index, node_list_->size()); value->node_ = node_list_->at(index); } - - // Get the node + void Visit(const char* key, runtime::NDArray* value) final { + size_t index; + ParseValue(key, &index); + CHECK_LE(index, tensor_list_->size()); + *value = tensor_list_->at(index); + } + // set node to be current JSONNode void Set(Node* node) { if (node == nullptr) return; if (node->is_type()) { @@ -292,6 +316,8 @@ struct JSONGraph { size_t root; // the nodes of the graph std::vector nodes; + // base64 b64ndarrays of arrays + std::vector b64ndarrays; // global attributes AttrMap attrs; @@ -299,6 +325,7 @@ struct JSONGraph { writer->BeginObject(); writer->WriteObjectKeyValue("root", root); writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays); if (attrs.size() != 0) { writer->WriteObjectKeyValue("attrs", attrs); } @@ -310,6 +337,7 @@ struct JSONGraph { dmlc::JSONObjectReadHelper helper; helper.DeclareField("root", &root); helper.DeclareField("nodes", &nodes); + helper.DeclareOptionalField("b64ndarrays", &b64ndarrays); helper.DeclareOptionalField("attrs", &attrs); helper.ReadAllFields(reader); } @@ -320,6 +348,7 @@ struct JSONGraph { indexer.MakeIndex(root.node_.get()); JSONAttrGetter getter; getter.node_index_ = &indexer.node_index; + getter.tensor_index_ = &indexer.tensor_index; for (Node* n : indexer.node_list) { JSONNode jnode; getter.node_ = &jnode; @@ -328,6 +357,15 @@ struct JSONGraph { } g.attrs["tvm_version"] = TVM_VERSION; g.root = indexer.node_index.at(root.node_.get()); + // serialize tensor + for (DLTensor* tensor : indexer.tensor_list) { + std::string blob; + dmlc::MemoryStringStream mstrm(&blob); + common::Base64OutStream b64strm(&mstrm); + runtime::SaveDLTensor(&b64strm, tensor); + b64strm.Finish(); + g.b64ndarrays.emplace_back(std::move(blob)); + } return g; } }; @@ -347,6 +385,16 @@ std::shared_ptr LoadJSON_(std::string json_str) { // load in json graph. jgraph.Load(&reader); std::vector > nodes; + std::vector tensors; + // load in tensors + for (const std::string& blob : jgraph.b64ndarrays) { + dmlc::MemoryStringStream mstrm(const_cast(&blob)); + common::Base64InStream b64strm(&mstrm); + b64strm.InitPosition(); + runtime::NDArray temp; + CHECK(temp.Load(&b64strm)); + tensors.emplace_back(temp); + } // node 0 is always null nodes.reserve(jgraph.nodes.size()); for (const JSONNode& jnode : jgraph.nodes) { @@ -362,6 +410,7 @@ std::shared_ptr LoadJSON_(std::string json_str) { CHECK_EQ(nodes.size(), jgraph.nodes.size()); JSONAttrSetter setter; setter.node_list_ = &nodes; + setter.tensor_list_ = &tensors; for (size_t i = 0; i < nodes.size(); ++i) { setter.node_ = &jgraph.nodes[i]; @@ -402,6 +451,9 @@ class NodeAttrSetter : public AttrVisitor { void Visit(const char* key, NodeRef* value) final { *value = GetAttr(key).operator NodeRef(); } + void Visit(const char* key, runtime::NDArray* value) final { + *value = GetAttr(key).operator runtime::NDArray(); + } private: runtime::TVMArgValue GetAttr(const char* key) { diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 89d5e7a28258a..7a75771af23b4 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -4,7 +4,7 @@ */ #include #include -#include +#include #include #include #include @@ -399,52 +399,9 @@ class GraphRuntime : public ModuleNode { void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { // always use strm->Read to maintain endianness conversion - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&reserved)) - << "Invalid DLTensor file format"; - CHECK(header == kTVMNDArrayMagic) - << "Invalid DLTensor file format"; - - DLTensor tensor; - CHECK(strm->Read(&(tensor.ctx))) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&(tensor.ndim))) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&(tensor.dtype))) - << "Invalid DLTensor file format"; - std::vector shape(tensor.ndim); - if (tensor.ndim != 0) { - CHECK(strm->ReadArray(&shape[0], tensor.ndim)) - << "Invalid DLTensor file format"; - } - CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch"; - CHECK(tensor.dtype.bits == dst->dtype.bits && - tensor.dtype.code == dst->dtype.code && - tensor.dtype.lanes == dst->dtype.lanes) << "param type mismatch"; - for (int i = 0; i < tensor.ndim; ++i) { - CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch"; - } - size_t bits = dst->dtype.bits * dst->dtype.lanes; - size_t elem_bytes = (bits + 7) / 8; - size_t num_elems = 1; - for (int i = 0; i < dst->ndim; ++i) { - num_elems *= dst->shape[i]; - } - uint64_t data_byte_size; - CHECK(strm->Read(&data_byte_size)) - << "Invalid DLTensor file format"; - CHECK_EQ(data_byte_size, elem_bytes * num_elems) - << "Invalid DLTensor file format"; - std::vector bytes(data_byte_size + 1); - CHECK(strm->Read(&bytes[0], data_byte_size)) - << "Invalid DLTensor file format"; - // explicitly swap endian when necessary. - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems); - } - TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size)); + NDArray temp; + temp.Load(strm); + temp.CopyTo(dst); } void GraphRuntime::LoadParams(dmlc::Stream* strm) { diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 8e2590dc6359b..7ebcf7d30b33f 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -13,8 +13,6 @@ namespace tvm { namespace runtime { -/*! \brief Magic number for NDArray file */ -constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; /*! \brief Magic number for NDArray list file */ constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;