diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index b67611a273aa..55fbd1c1bcf4 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -68,6 +68,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API) "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc" ) endif() diff --git a/ffi/include/tvm/ffi/extra/base64.h b/ffi/include/tvm/ffi/extra/base64.h new file mode 100644 index 000000000000..136fec2e7f84 --- /dev/null +++ b/ffi/include/tvm/ffi/extra/base64.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * + * \file tvm/ffi/extra/base64.h + * \brief Base64 encoding and decoding utilities + */ +#ifndef TVM_FFI_EXTRA_BASE64_H_ +#define TVM_FFI_EXTRA_BASE64_H_ + +#include + +#include + +namespace tvm { +namespace ffi { +/*! 
+ * \brief Encode a byte array into a base64 string + * \param bytes The byte array to encode + * \return The base64 encoded string + */ +inline String Base64Encode(TVMFFIByteArray bytes) { + // encoding every 3 bytes into 4 characters + constexpr const char kEncodeTable[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string encoded; + encoded.reserve(4 * (bytes.size + 2) / 3); + + for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { + int32_t buf[3]; + buf[0] = static_cast(bytes.data[i]); + buf[1] = static_cast(bytes.data[i + 1]); + buf[2] = static_cast(bytes.data[i + 2]); + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); + encoded.push_back(kEncodeTable[buf[2] & 0x3F]); + } + if (bytes.size % 3 == 1) { + int32_t buf[1] = {static_cast(bytes.data[bytes.size - 1])}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); + encoded.push_back('='); + encoded.push_back('='); + } else if (bytes.size % 3 == 2) { + int32_t buf[2] = {static_cast(bytes.data[bytes.size - 2]), + static_cast(bytes.data[bytes.size - 1])}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); + encoded.push_back('='); + } + return String(encoded); +} + +/*! + * \brief Encode a bytes object into a base64 string + * \param data The bytes object to encode + * \return The base64 encoded string + */ +inline String Base64Encode(const Bytes& data) { + return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); +} + +/*! 
+ * \brief Decode a base64 string into a byte array + * \param data The base64 encoded string to decode + * \return The decoded byte array + */ +inline Bytes Base64Decode(TVMFFIByteArray bytes) { + constexpr const char kDecodeTable[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' + }; + std::string decoded; + decoded.reserve(bytes.size * 3 / 4); + if (bytes.size == 0) return Bytes(); + TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; + // leverage this property to simplify decoding + static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast('=')] == 0); + // base64 is always multiple of 4 bytes + for (size_t i = 0; i < bytes.size; i += 4) { + // decode every 4 characters into 24bits, each character contains 6 bits + // note that = is also decoded as 0, which is safe to skip + int32_t buf[4] = { + static_cast(bytes.data[i]), + static_cast(bytes.data[i + 1]), + static_cast(bytes.data[i + 2]), + static_cast(bytes.data[i + 3]), + }; + int32_t value_i24 = (static_cast(kDecodeTable[buf[0]]) << 18) | + (static_cast(kDecodeTable[buf[1]]) << 12) | + (static_cast(kDecodeTable[buf[2]]) << 6) | + static_cast(kDecodeTable[buf[3]]); + // unpack 24bits into 3 bytes, each contains 8 bits + decoded.push_back(static_cast((value_i24 >> 16) & 0xFF)); + if (buf[2] != '=') { + decoded.push_back(static_cast((value_i24 >> 8) & 0xFF)); + } + if (buf[3] != '=') { + decoded.push_back(static_cast(value_i24 & 0xFF)); + } + } + return Bytes(decoded); +} + +/*! 
+ * \brief Decode a base64 string into a byte array + * \param data The base64 encoded string to decode + * \return The decoded byte array + */ +inline Bytes Base64Decode(const String& data) { + return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); +} + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_BASE64_H_ diff --git a/ffi/include/tvm/ffi/extra/json.h b/ffi/include/tvm/ffi/extra/json.h index 847e60c0f694..409f7aa52560 100644 --- a/ffi/include/tvm/ffi/extra/json.h +++ b/ffi/include/tvm/ffi/extra/json.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file tvm/ffi/json/json.h + * \file tvm/ffi/extra/json.h * \brief Minimal lightweight JSON parsing and serialization utilities */ #ifndef TVM_FFI_EXTRA_JSON_H_ diff --git a/ffi/include/tvm/ffi/extra/serialization.h b/ffi/include/tvm/ffi/extra/serialization.h new file mode 100644 index 000000000000..c08ad81cc363 --- /dev/null +++ b/ffi/include/tvm/ffi/extra/serialization.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file tvm/ffi/extra/serialization.h
+ * \brief Reflection-based serialization utilities
+ */
+#ifndef TVM_FFI_EXTRA_SERIALIZATION_H_
+#define TVM_FFI_EXTRA_SERIALIZATION_H_
+
+#include
+#include
+
+namespace tvm {
+namespace ffi {
+
+/**
+ * \brief Serialize ffi::Any to a JSON that stores the object graph.
+ *
+ * The JSON graph structure is stored as follows:
+ *
+ * ```json
+ * {
+ *   "root_index": <int>,       // Index of root node in nodes array
+ *   "nodes": [<node>, ...],    // Array of serialized nodes
+ *   "metadata": <any>          // Optional metadata
+ * }
+ * ```
+ *
+ * Each node has the format: `{"type": "<type_key>", "data": <data>}`
+ * For arrays, maps and objects, the data contains indices to other nodes.
+ * For object fields whose static type is known as a primitive type, it is stored directly,
+ * otherwise, it is stored as a reference to the nodes array by an index.
+ *
+ * This function preserves the type and multiple references to the same object,
+ * which is useful for debugging and serialization.
+ *
+ * \param value The ffi::Any value to serialize.
+ * \param metadata Extra metadata attached to "metadata" field of the JSON object.
+ * \return The serialized JSON value.
+ */
+TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metadata = Any(nullptr));
+
+/**
+ * \brief Deserialize a JSON that stores the object graph to an ffi::Any value.
+ *
+ * The input is expected to follow the layout documented for ToJSONGraph,
+ * the "metadata" field is not needed to rebuild the value.
+ *
+ * \param value The JSON value to deserialize.
+ * \return The deserialized object graph.
+ */
+TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value& value);
+
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_EXTRA_SERIALIZATION_H_
diff --git a/ffi/src/ffi/extra/serialization.cc b/ffi/src/ffi/extra/serialization.cc
new file mode 100644
index 000000000000..b3230f38fb58
--- /dev/null
+++ b/ffi/src/ffi/extra/serialization.cc
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/ffi/extra/serialization.cc
+ *
+ * \brief Reflection-based serialization utilities.
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +class ObjectGraphSerializer { + public: + static json::Value Serialize(const Any& value, Any metadata) { + ObjectGraphSerializer serializer; + json::Object result; + result.Set("root_index", serializer.GetOrCreateNodeIndex(value)); + result.Set("nodes", std::move(serializer.nodes_)); + if (metadata != nullptr) { + result.Set("metadata", metadata); + } + return result; + } + + private: + ObjectGraphSerializer() = default; + + int64_t GetOrCreateNodeIndex(const Any& value) { + // already mapped value, return the index + auto it = node_index_map_.find(value); + if (it != node_index_map_.end()) { + return (*it).second; + } + json::Object node; + switch (value.type_index()) { + case TypeIndex::kTVMFFINone: { + node.Set("type", ffi::StaticTypeKey::kTVMFFINone); + break; + } + case TypeIndex::kTVMFFIBool: { + node.Set("type", ffi::StaticTypeKey::kTVMFFIBool); + node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); + break; + } + case TypeIndex::kTVMFFIInt: { + node.Set("type", ffi::StaticTypeKey::kTVMFFIInt); + node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); + break; + } + case TypeIndex::kTVMFFIFloat: { + node.Set("type", ffi::StaticTypeKey::kTVMFFIFloat); + node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); + break; + } + case TypeIndex::kTVMFFIDataType: { + DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIDataType); + node.Set("data", DLDataTypeToString(dtype)); + break; + } + case TypeIndex::kTVMFFIDevice: { + DLDevice device = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIDevice); + node.Set("data", json::Array{ + static_cast(device.device_type), + static_cast(device.device_id), + }); + break; + } + case TypeIndex::kTVMFFISmallStr: + case 
TypeIndex::kTVMFFIStr: { + String str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIStr); + node.Set("data", str); + break; + } + case TypeIndex::kTVMFFISmallBytes: + case TypeIndex::kTVMFFIBytes: { + Bytes bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIBytes); + node.Set("data", Base64Encode(bytes)); + break; + } + case TypeIndex::kTVMFFIArray: { + Array array = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIArray); + node.Set("data", CreateArrayData(array)); + break; + } + case TypeIndex::kTVMFFIMap: { + Map map = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIMap); + node.Set("data", CreateMapData(map)); + break; + } + default: { + if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) { + // serialize type key since type index is runtime dependent + node.Set("type", value.GetTypeKey()); + node.Set("data", CreateObjectData(value)); + } else { + TVM_FFI_THROW(RuntimeError) << "Cannot serialize type `" << value.GetTypeKey() << "`"; + TVM_FFI_UNREACHABLE(); + } + } + } + int64_t node_index = nodes_.size(); + nodes_.push_back(node); + node_index_map_.Set(value, node_index); + return node_index; + } + + json::Array CreateArrayData(const Array& value) { + json::Array data; + data.reserve(value.size()); + for (const Any& item : value) { + data.push_back(GetOrCreateNodeIndex(item)); + } + return data; + } + + json::Array CreateMapData(const Map& value) { + json::Array data; + data.reserve(value.size() * 2); + for (const auto& [key, value] : value) { + data.push_back(GetOrCreateNodeIndex(key)); + data.push_back(GetOrCreateNodeIndex(value)); + } + return data; + } + + // create the data for the object, if the type has a custom data to json function, + // use it. otherwise, we go over the fields and create the data. 
+ json::Object CreateObjectData(const Any& value) { + static reflection::TypeAttrColumn data_to_json = reflection::TypeAttrColumn("__data_to_json__"); + if (data_to_json[value.type_index()] != nullptr) { + return data_to_json[value.type_index()].cast()(value).cast(); + } + // NOTE: invariant: lhs and rhs are already the same type + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index()); + if (type_info->metadata == nullptr) { + TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" + << String(type_info->type_key) + << "`, so ToJSONGraph is not supported for this type"; + } + const Object* obj = value.cast(); + json::Object data; + // go over the content and hash the fields + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + // get the field value from both side + reflection::FieldGetter getter(field_info); + Any field_value = getter(obj); + int field_static_type_index = field_info->field_static_type_index; + String field_name(field_info->name); + // for static field index that are known, we can directly set the field value. 
+ switch (field_static_type_index) { + case TypeIndex::kTVMFFINone: { + data.Set(field_name, nullptr); + break; + } + case TypeIndex::kTVMFFIBool: { + data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); + break; + } + case TypeIndex::kTVMFFIInt: { + data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); + break; + } + case TypeIndex::kTVMFFIFloat: { + data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); + break; + } + case TypeIndex::kTVMFFIDataType: { + DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value); + data.Set(field_name, DLDataTypeToString(dtype)); + break; + } + default: { + // for dynamic field index, we need need to put them onto nodes + int64_t node_index = GetOrCreateNodeIndex(field_value); + data.Set(field_name, node_index); + break; + } + } + }); + return data; + } + + // maps the original value to the index of the node in the nodes_ array + Map node_index_map_; + // records nodes that are serialized + json::Array nodes_; +}; + +json::Value ToJSONGraph(const Any& value, const Any& metadata) { + return ObjectGraphSerializer::Serialize(value, metadata); +} + +class ObjectGraphDeserializer { + public: + static Any Deserialize(const json::Value& value) { + ObjectGraphDeserializer deserializer(value); + return deserializer.GetOrDecodeNode(deserializer.root_index_); + } + + Any GetOrDecodeNode(int64_t node_index) { + // already decoded null index + if (node_index == decoded_null_index_) { + return Any(nullptr); + } + // already decoded + if (decoded_nodes_[node_index] != nullptr) { + return decoded_nodes_[node_index]; + } + // now decode the node + Any value = DecodeNode(nodes_[node_index].cast()); + decoded_nodes_[node_index] = value; + if (value == nullptr) { + decoded_null_index_ = node_index; + } + return value; + } + + private: + Any DecodeNode(const json::Object& node) { + String type_key = node["type"].cast(); + TVMFFIByteArray 
type_key_arr{type_key.data(), type_key.length()}; + int32_t type_index; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); + + switch (type_index) { + case TypeIndex::kTVMFFINone: { + return nullptr; + } + case TypeIndex::kTVMFFIBool: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIInt: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIFloat: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIDataType: { + return StringToDLDataType(node["data"].cast()); + } + case TypeIndex::kTVMFFIDevice: { + Array data = node["data"].cast>(); + return DLDevice{static_cast(data[0]), data[1]}; + } + case TypeIndex::kTVMFFIStr: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIBytes: { + return Base64Decode(node["data"].cast()); + } + case TypeIndex::kTVMFFIMap: { + return DecodeMapData(node["data"].cast()); + } + case TypeIndex::kTVMFFIArray: { + return DecodeArrayData(node["data"].cast()); + } + default: { + return DecodeObjectData(type_index, node["data"]); + } + } + } + + Array DecodeArrayData(const json::Array& data) { + Array array; + array.reserve(data.size()); + for (size_t i = 0; i < data.size(); i++) { + array.push_back(GetOrDecodeNode(data[i].cast())); + } + return array; + } + + Map DecodeMapData(const json::Array& data) { + Map map; + for (size_t i = 0; i < data.size(); i += 2) { + int64_t key_index = data[i].cast(); + int64_t value_index = data[i + 1].cast(); + map.Set(GetOrDecodeNode(key_index), GetOrDecodeNode(value_index)); + } + return map; + } + + Any DecodeObjectData(int32_t type_index, const json::Value& data) { + static reflection::TypeAttrColumn data_from_json = + reflection::TypeAttrColumn("__data_from_json__"); + if (data_from_json[type_index] != nullptr) { + return data_from_json[type_index].cast()(data); + } + // otherwise, we go over the fields and create the data. 
+ const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not support default constructor" + << ", so ToJSONGraph is not supported for this type"; + } + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); + ObjectPtr ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + + auto decode_field_value = [&](const TVMFFIFieldInfo* field_info, json::Value data) -> Any { + switch (field_info->field_static_type_index) { + case TypeIndex::kTVMFFINone: { + return nullptr; + } + case TypeIndex::kTVMFFIBool: { + return data.cast(); + } + case TypeIndex::kTVMFFIInt: { + return data.cast(); + } + case TypeIndex::kTVMFFIFloat: { + return data.cast(); + } + case TypeIndex::kTVMFFIDataType: { + return StringToDLDataType(data.cast()); + } + default: { + return GetOrDecodeNode(data.cast()); + } + } + }; + + json::Object data_object = data.cast(); + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + String field_name(field_info->name); + void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; + if (data_object.count(field_name) != 0) { + Any field_value = decode_field_value(field_info, data_object[field_name]); + field_info->setter(field_addr, reinterpret_cast(&field_value)); + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + field_info->setter(field_addr, &(field_info->default_value)); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" + << String(field_info->name.data, field_info->name.size) + << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; + } + }); + return ObjectRef(ptr); + } + + explicit ObjectGraphDeserializer(json::Value serialized) { + if (!serialized.as()) { + TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected an object"; + } + 
json::Object encoded_object = serialized.cast(); + if (encoded_object.count("root_index") == 0 || !encoded_object["root_index"].as()) { + TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `root_index` integer field"; + } + if (encoded_object.count("nodes") == 0 || !encoded_object["nodes"].as()) { + TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `nodes` array field"; + } + root_index_ = encoded_object["root_index"].cast(); + nodes_ = encoded_object["nodes"].cast(); + decoded_nodes_.resize(nodes_.size(), Any(nullptr)); + } + // nodes + json::Array nodes_; + // root index + int64_t root_index_; + // null index if already created + int64_t decoded_null_index_{-1}; + // decoded nodes + std::vector decoded_nodes_; +}; + +Any FromJSONGraph(const json::Value& value) { return ObjectGraphDeserializer::Deserialize(value); } + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ffi.ToJSONGraph", ToJSONGraph).def("ffi.FromJSONGraph", FromJSONGraph); + refl::EnsureTypeAttrColumn("__data_to_json__"); + refl::EnsureTypeAttrColumn("__data_from_json__"); +}); + +} // namespace ffi +} // namespace tvm diff --git a/ffi/tests/cpp/extra/test_serialization.cc b/ffi/tests/cpp/extra/test_serialization.cc new file mode 100644 index 000000000000..f0aefa370966 --- /dev/null +++ b/ffi/tests/cpp/extra/test_serialization.cc @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Serialization, BoolNull) { + json::Object expected_null = + json::Object{{"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "None"}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(nullptr), expected_null)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_null), nullptr)); + + json::Object expected_true = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(true), expected_true)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_true), true)); + + json::Object expected_false = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", false}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(false), expected_false)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_false), false)); +} + +TEST(Serialization, IntegerTypes) { + // Test positive integer + json::Object expected_int = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "int"}, {"data", 42}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(static_cast(42)), expected_int)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int), static_cast(42))); +} + +TEST(Serialization, FloatTypes) { + // Test positive float + json::Object expected_float = + json::Object{{"root_index", 0}, + {"nodes", 
json::Array{json::Object{{"type", "float"}, {"data", 3.14159}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(3.14159), expected_float)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float), 3.14159)); +} + +TEST(Serialization, StringTypes) { + // Test short string + json::Object expected_short = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("hello")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello")), expected_short)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_short), String("hello"))); + + // Test long string + std::string long_str(1000, 'x'); + json::Object expected_long = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String(long_str)}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String(long_str)), expected_long)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_long), String(long_str))); + + // Test string with special characters + json::Object expected_special = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, + {"data", String("hello\nworld\t\"quotes\"")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello\nworld\t\"quotes\"")), expected_special)); + EXPECT_TRUE( + StructuralEqual()(FromJSONGraph(expected_special), String("hello\nworld\t\"quotes\""))); +} + +TEST(Serialization, Bytes) { + // Test empty bytes + Bytes empty_bytes; + json::Object expected_empty = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", ""}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_bytes), expected_empty)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_bytes)); + + // Test bytes with that encoded as base64 + Bytes bytes_content = Bytes("abcd"); + json::Object expected_encoded = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, 
{"data", "YWJjZA=="}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_content), expected_encoded)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded), bytes_content)); + + // Test bytes with that encoded as base64, that contains control characters via utf-8 + char bytes_v2_content[] = {0x01, 0x02, 0x03, 0x04, 0x01, 0x0b}; + Bytes bytes_v2 = Bytes(bytes_v2_content, sizeof(bytes_v2_content)); + json::Object expected_encoded_v2 = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", "AQIDBAEL"}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_v2), expected_encoded_v2)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded_v2), bytes_v2)); +} + +TEST(Serialization, DataTypes) { + // Test int32 dtype + DLDataType int32_dtype; + int32_dtype.code = kDLInt; + int32_dtype.bits = 32; + int32_dtype.lanes = 1; + + json::Object expected_int32 = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("int32")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(int32_dtype), expected_int32)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int32), int32_dtype)); + + // Test float64 dtype + DLDataType float64_dtype; + float64_dtype.code = kDLFloat; + float64_dtype.bits = 64; + float64_dtype.lanes = 1; + + json::Object expected_float64 = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float64")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(float64_dtype), expected_float64)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float64), float64_dtype)); + + // Test vector dtype + DLDataType vector_dtype; + vector_dtype.code = kDLFloat; + vector_dtype.bits = 32; + vector_dtype.lanes = 4; + + json::Object expected_vector = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float32x4")}}}}}; + 
EXPECT_TRUE(StructuralEqual()(ToJSONGraph(vector_dtype), expected_vector)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_vector), vector_dtype)); +} + +TEST(Serialization, DeviceTypes) { + // Test CPU device + DLDevice cpu_device; + cpu_device.device_type = kDLCPU; + cpu_device.device_id = 0; + + json::Object expected_cpu = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "Device"}, + {"data", json::Array{static_cast(kDLCPU), + static_cast(0)}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(cpu_device), expected_cpu)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_cpu), cpu_device)); + + // Test GPU device + DLDevice gpu_device; + gpu_device.device_type = kDLCUDA; + gpu_device.device_id = 1; + + json::Object expected_gpu = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{ + {"type", "Device"}, {"data", json::Array{static_cast(kDLCUDA), 1}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(gpu_device), expected_gpu)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_gpu), gpu_device)); +} + +TEST(Serialization, Arrays) { + // Test empty array + Array empty_array; + json::Object expected_empty = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_array), expected_empty)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_array)); + + // Test single element array + Array single_array; + single_array.push_back(Any(42)); + json::Object expected_single = + json::Object{{"root_index", 1}, + {"nodes", json::Array{ + json::Object{{"type", "int"}, {"data", static_cast(42)}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{0}}}, + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_array), expected_single)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_array)); + + // Test duplicated element array + Array 
duplicated_array; + duplicated_array.push_back(42); + duplicated_array.push_back(42); + json::Object expected_duplicated = + json::Object{{"root_index", 1}, + {"nodes", json::Array{ + json::Object{{"type", "int"}, {"data", 42}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 0}}}, + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_array), expected_duplicated)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_array)); + // Test mixed element array, note that 42 and "hello" are duplicated and will + // be indexed as 0 and 1 + Array mixed_array; + mixed_array.push_back(42); + mixed_array.push_back(String("hello")); + mixed_array.push_back(true); + mixed_array.push_back(nullptr); + mixed_array.push_back(42); + mixed_array.push_back(String("hello")); + json::Object expected_mixed = json::Object{ + {"root_index", 4}, + {"nodes", json::Array{ + json::Object{{"type", "int"}, {"data", 42}}, + json::Object{{"type", "ffi.String"}, {"data", String("hello")}}, + json::Object{{"type", "bool"}, {"data", true}}, + json::Object{{"type", "None"}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 1, 2, 3, 0, 1}}}, + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(mixed_array), expected_mixed)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_mixed), mixed_array)); +} + +TEST(Serialization, Maps) { + // Test empty map + Map empty_map; + json::Object expected_empty = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Map"}, {"data", json::Array{}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_map), expected_empty)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_map)); + + // Test single element map + Map single_map{{"key", 42}}; + json::Object expected_single = json::Object{ + {"root_index", 2}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("key")}}, + json::Object{{"type", "int"}, {"data", 42}}, + 
json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_map), expected_single)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_map)); + + // Test duplicated element map + Map duplicated_map{{"b", 42}, {"a", 42}}; + json::Object expected_duplicated = json::Object{ + {"root_index", 3}, + {"nodes", json::Array{ + json::Object{{"type", "ffi.String"}, {"data", "b"}}, + json::Object{{"type", "int"}, {"data", 42}}, + json::Object{{"type", "ffi.String"}, {"data", "a"}}, + json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1, 2, 1}}}, + + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_map), expected_duplicated)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_map)); +} + +TEST(Serialization, TestObjectVar) { + TVar x = TVar("x"); + json::Object expected_x = json::Object{ + {"root_index", 1}, + {"nodes", + json::Array{json::Object{{"type", "ffi.String"}, {"data", "x"}}, + json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(x), expected_x)); + EXPECT_TRUE(StructuralEqual::Equal(FromJSONGraph(expected_x), x, /*map_free_vars=*/true)); +} + +TEST(Serialization, TestObjectIntCustomToJSON) { + TInt value = TInt(42); + json::Object expected_i = json::Object{ + {"root_index", 0}, + {"nodes", + json::Array{json::Object{{"type", "test.Int"}, {"data", json::Object{{"value", 42}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value), expected_i)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_i), value)); +} + +TEST(Serialization, TestObjectFunc) { + TVar x = TVar("x"); + // comment fields are ignored + TFunc fa = TFunc({x}, {x, x}, String("comment a")); + + json::Object expected_fa = json::Object{ + {"root_index", 5}, + {"nodes", + json::Array{ + json::Object{{"type", "ffi.String"}, {"data", "x"}}, // string "x" + json::Object{{"type", "test.Var"}, {"data", 
json::Object{{"name", 0}}}}, // var x + json::Object{{"type", "ffi.Array"}, {"data", json::Array{1}}}, // array [x] + json::Object{{"type", "ffi.Array"}, {"data", json::Array{1, 1}}}, // array [x, x] + json::Object{{"type", "ffi.String"}, {"data", "comment a"}}, // "comment a" + json::Object{{"type", "test.Func"}, + {"data", json::Object{{"params", 2}, {"body", 3}, {"comment", 4}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fa), expected_fa)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fa), fa)); + + TFunc fb = TFunc({}, {}, std::nullopt); + json::Object expected_fb = json::Object{ + {"root_index", 3}, + {"nodes", + json::Array{ + json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, + json::Object{{"type", "None"}}, + json::Object{{"type", "test.Func"}, + {"data", json::Object{{"params", 0}, {"body", 1}, {"comment", 2}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fb), expected_fb)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fb), fb)); +} + +TEST(Serialization, AttachMetadata) { + bool value = true; + json::Object metadata{{"version", "1.0"}}; + json::Object expected = + json::Object{{"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}, + {"metadata", metadata}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value, metadata), expected)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), value)); +} + +TEST(Serialization, ShuffleNodeOrder) { + // the FromJSONGraph is agnostic to the node order + // so we can shuffle the node order as it reads nodes lazily + Map duplicated_map{{"b", 42}, {"a", 42}}; + json::Object expected_shuffled = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{ + json::Object{{"type", "ffi.Map"}, {"data", json::Array{2, 3, 1, 3}}}, + json::Object{{"type", "ffi.String"}, {"data", "a"}}, + json::Object{{"type", "ffi.String"}, {"data", "b"}}, + json::Object{{"type", "int"}, {"data", 42}}, 
+ }}}; + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shuffled), duplicated_map)); +} + +} // namespace diff --git a/ffi/tests/cpp/extra/test_structural_equal_hash.cc b/ffi/tests/cpp/extra/test_structural_equal_hash.cc index 76c485d9062e..8a377f483713 100644 --- a/ffi/tests/cpp/extra/test_structural_equal_hash.cc +++ b/ffi/tests/cpp/extra/test_structural_equal_hash.cc @@ -147,10 +147,10 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) { TVar x = TVar("x"); TVar y = TVar("y"); // comment fields are ignored - TFunc fa = TFunc({x}, {TInt(1), x}, "comment a"); - TFunc fb = TFunc({y}, {TInt(1), y}, "comment b"); + TFunc fa = TFunc({x}, {TInt(1), x}, String("comment a")); + TFunc fb = TFunc({y}, {TInt(1), y}, String("comment b")); - TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, "comment c"); + TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, String("comment c")); EXPECT_TRUE(StructuralEqual()(fa, fb)); EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index 78ca008e1094..c5725da941f5 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -21,6 +21,7 @@ #define TVM_FFI_TESTING_OBJECT_H_ #include +#include #include #include #include @@ -87,6 +88,15 @@ inline void TIntObj::RegisterReflection() { refl::TypeAttrDef() .def("test.GetValue", &TIntObj::GetValue) .attr("test.size", sizeof(TIntObj)); + // custom json serialization + refl::TypeAttrDef() + .def("__data_to_json__", + [](const TIntObj* self) -> Map { + return Map{{"value", self->value}}; + }) + .def("__data_from_json__", [](Map json_obj) -> TInt { + return TInt(json_obj["value"].cast()); + }); } class TFloatObj : public TNumberObj { @@ -154,6 +164,8 @@ class TVarObj : public Object { public: std::string name; + // need default constructor for json serialization + TVarObj() = default; TVarObj(std::string name) : name(name) {} static void RegisterReflection() { @@ -178,9 +190,11 @@ class TFuncObj : public Object 
{ public: Array params; Array body; - String comment; + Optional comment; - TFuncObj(Array params, Array body, String comment) + // need default constructor for json serialization + TFuncObj() = default; + TFuncObj(Array params, Array body, Optional comment) : params(params), body(body), comment(comment) {} static void RegisterReflection() { @@ -198,7 +212,7 @@ class TFuncObj : public Object { class TFunc : public ObjectRef { public: - explicit TFunc(Array params, Array body, String comment) { + explicit TFunc(Array params, Array body, Optional comment) { data_ = make_object(params, body, comment); }