From 748537ebc281f2ece007501bc2550e7bfc0622f0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 4 Aug 2025 10:55:43 -0400 Subject: [PATCH] [FFI] Serialization To/From JSONGraph This PR implements serialization function for generic ffi::Any based on the reflection that is preserves the overall object graph reference relation. These extra APIs are implemented through the reflection system. They can be used to further modernize and unify the serialization mechanisms in the project under the new reflection mechanism. --- ffi/CMakeLists.txt | 1 + ffi/include/tvm/ffi/extra/base64.h | 142 ++++++ ffi/include/tvm/ffi/extra/json.h | 2 +- ffi/include/tvm/ffi/extra/serialization.h | 72 ++++ ffi/src/ffi/extra/serialization.cc | 408 ++++++++++++++++++ ffi/tests/cpp/extra/test_serialization.cc | 354 +++++++++++++++ .../cpp/extra/test_structural_equal_hash.cc | 6 +- ffi/tests/cpp/testing_object.h | 20 +- 8 files changed, 998 insertions(+), 7 deletions(-) create mode 100644 ffi/include/tvm/ffi/extra/base64.h create mode 100644 ffi/include/tvm/ffi/extra/serialization.h create mode 100644 ffi/src/ffi/extra/serialization.cc create mode 100644 ffi/tests/cpp/extra/test_serialization.cc diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index b67611a273aa..55fbd1c1bcf4 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -68,6 +68,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API) "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc" ) endif() diff --git a/ffi/include/tvm/ffi/extra/base64.h b/ffi/include/tvm/ffi/extra/base64.h new file mode 100644 index 000000000000..136fec2e7f84 --- /dev/null +++ b/ffi/include/tvm/ffi/extra/base64.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * + * \file tvm/ffi/extra/base64.h + * \brief Base64 encoding and decoding utilities + */ +#ifndef TVM_FFI_EXTRA_BASE64_H_ +#define TVM_FFI_EXTRA_BASE64_H_ + +#include + +#include + +namespace tvm { +namespace ffi { +/*! + * \brief Encode a byte array into a base64 string + * \param bytes The byte array to encode + * \return The base64 encoded string + */ +inline String Base64Encode(TVMFFIByteArray bytes) { + // encoding every 3 bytes into 4 characters + constexpr const char kEncodeTable[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string encoded; + encoded.reserve(4 * (bytes.size + 2) / 3); + + for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { + int32_t buf[3]; + buf[0] = static_cast(bytes.data[i]); + buf[1] = static_cast(bytes.data[i + 1]); + buf[2] = static_cast(bytes.data[i + 2]); + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); + encoded.push_back(kEncodeTable[buf[2] & 0x3F]); + } + if (bytes.size % 3 == 1) { + int32_t buf[1] = {static_cast(bytes.data[bytes.size - 1])}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); + encoded.push_back('='); + encoded.push_back('='); + } else if (bytes.size % 3 == 2) { + int32_t buf[2] = {static_cast(bytes.data[bytes.size - 2]), + static_cast(bytes.data[bytes.size - 1])}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); + encoded.push_back('='); + } + return String(encoded); +} + +/*! + * \brief Encode a bytes object into a base64 string + * \param data The bytes object to encode + * \return The base64 encoded string + */ +inline String Base64Encode(const Bytes& data) { + return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); +} + +/*! + * \brief Decode a base64 string into a byte array + * \param data The base64 encoded string to decode + * \return The decoded byte array + */ +inline Bytes Base64Decode(TVMFFIByteArray bytes) { + constexpr const char kDecodeTable[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' + }; + std::string decoded; + decoded.reserve(bytes.size * 3 / 4); + if (bytes.size == 0) return Bytes(); + TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; + // leverage this property to simplify decoding + static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast('=')] == 0); + // base64 is always multiple of 4 bytes + for (size_t i = 0; i < bytes.size; i += 4) { + // decode every 4 characters into 24bits, each character contains 6 bits + // note that = is also decoded as 0, which is safe to skip + int32_t buf[4] = { + static_cast(bytes.data[i]), + static_cast(bytes.data[i + 1]), + static_cast(bytes.data[i + 2]), + static_cast(bytes.data[i + 3]), + }; + int32_t value_i24 = (static_cast(kDecodeTable[buf[0]]) << 18) | + (static_cast(kDecodeTable[buf[1]]) << 12) | + (static_cast(kDecodeTable[buf[2]]) << 6) | + static_cast(kDecodeTable[buf[3]]); + // unpack 24bits into 3 bytes, each contains 8 bits + decoded.push_back(static_cast((value_i24 >> 16) & 0xFF)); + if (buf[2] != '=') { + decoded.push_back(static_cast((value_i24 >> 8) & 0xFF)); + } + if (buf[3] != '=') { + decoded.push_back(static_cast(value_i24 & 0xFF)); + } + } + return Bytes(decoded); +} + +/*! + * \brief Decode a base64 string into a byte array + * \param data The base64 encoded string to decode + * \return The decoded byte array + */ +inline Bytes Base64Decode(const String& data) { + return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); +} + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_BASE64_H_ diff --git a/ffi/include/tvm/ffi/extra/json.h b/ffi/include/tvm/ffi/extra/json.h index 847e60c0f694..409f7aa52560 100644 --- a/ffi/include/tvm/ffi/extra/json.h +++ b/ffi/include/tvm/ffi/extra/json.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file tvm/ffi/json/json.h + * \file tvm/ffi/extra/json.h * \brief Minimal lightweight JSON parsing and serialization utilities */ #ifndef TVM_FFI_EXTRA_JSON_H_ diff --git a/ffi/include/tvm/ffi/extra/serialization.h b/ffi/include/tvm/ffi/extra/serialization.h new file mode 100644 index 000000000000..c08ad81cc363 --- /dev/null +++ b/ffi/include/tvm/ffi/extra/serialization.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/serialization.h + * \brief Reflection-based serialization utilities + */ +#ifndef TVM_FFI_EXTRA_SERIALIZATION_H_ +#define TVM_FFI_EXTRA_SERIALIZATION_H_ + +#include +#include + +namespace tvm { +namespace ffi { + +/** + * \brief Serialize ffi::Any to a JSON that stores the object graph. + * + * The JSON graph structure is stored as follows: + * + * ```json + * { + * "root_index": , // Index of root node in nodes array + * "nodes": [, ...], // Array of serialized nodes + * "metadata": // Optional metadata + * } + * ``` + * + * Each node has the format: `{"type": "", "data": }` + * For object types and strings, the data may contain indices to other nodes. + * For object fields whose static type is known as a primitive type, it is stored directly, + * otherwise, it is stored as a reference to the nodes array by an index. + * + * This function preserves the type and multiple references to the same object, + * which is useful for debugging and serialization. + * + * \param value The ffi::Any value to serialize. + * \param metadata Extra metadata attached to "metadata" field of the JSON object. + * \return The serialized JSON value. + */ +TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metadata = Any(nullptr)); + +/** + * \brief Deserialize a JSON that stores the object graph to an ffi::Any value. + * + * This function can be used to implement deserialization + * and debugging. + * + * \param value The JSON value to deserialize. + * \return The deserialized object graph. + */ +TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value& value); + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_SERIALIZATION_H_ diff --git a/ffi/src/ffi/extra/serialization.cc b/ffi/src/ffi/extra/serialization.cc new file mode 100644 index 000000000000..b3230f38fb58 --- /dev/null +++ b/ffi/src/ffi/extra/serialization.cc @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/serialization.cc + * + * \brief Reflection-based serialization utilities. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +class ObjectGraphSerializer { + public: + static json::Value Serialize(const Any& value, Any metadata) { + ObjectGraphSerializer serializer; + json::Object result; + result.Set("root_index", serializer.GetOrCreateNodeIndex(value)); + result.Set("nodes", std::move(serializer.nodes_)); + if (metadata != nullptr) { + result.Set("metadata", metadata); + } + return result; + } + + private: + ObjectGraphSerializer() = default; + + int64_t GetOrCreateNodeIndex(const Any& value) { + // already mapped value, return the index + auto it = node_index_map_.find(value); + if (it != node_index_map_.end()) { + return (*it).second; + } + json::Object node; + switch (value.type_index()) { + case TypeIndex::kTVMFFINone: { + node.Set("type", ffi::StaticTypeKey::kTVMFFINone); + break; + } + case TypeIndex::kTVMFFIBool: { + node.Set("type", ffi::StaticTypeKey::kTVMFFIBool); + node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); + break; + } + case TypeIndex::kTVMFFIInt: { + node.Set("type", ffi::StaticTypeKey::kTVMFFIInt); + node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); + break; + } + case TypeIndex::kTVMFFIFloat: { + node.Set("type", ffi::StaticTypeKey::kTVMFFIFloat); + node.Set("data", details::AnyUnsafe::CopyFromAnyViewAfterCheck(value)); + break; + } + case TypeIndex::kTVMFFIDataType: { + DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIDataType); + node.Set("data", DLDataTypeToString(dtype)); + break; + } + case TypeIndex::kTVMFFIDevice: { + DLDevice device = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIDevice); + node.Set("data", json::Array{ + static_cast(device.device_type), + static_cast(device.device_id), + }); + break; + } + case TypeIndex::kTVMFFISmallStr: + case TypeIndex::kTVMFFIStr: { + String str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIStr); + node.Set("data", str); + break; + } + case TypeIndex::kTVMFFISmallBytes: + case TypeIndex::kTVMFFIBytes: { + Bytes bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIBytes); + node.Set("data", Base64Encode(bytes)); + break; + } + case TypeIndex::kTVMFFIArray: { + Array array = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIArray); + node.Set("data", CreateArrayData(array)); + break; + } + case TypeIndex::kTVMFFIMap: { + Map map = details::AnyUnsafe::CopyFromAnyViewAfterCheck>(value); + node.Set("type", ffi::StaticTypeKey::kTVMFFIMap); + node.Set("data", CreateMapData(map)); + break; + } + default: { + if (value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin) { + // serialize type key since type index is runtime dependent + node.Set("type", value.GetTypeKey()); + node.Set("data", CreateObjectData(value)); + } else { + TVM_FFI_THROW(RuntimeError) << "Cannot serialize type `" << value.GetTypeKey() << "`"; + TVM_FFI_UNREACHABLE(); + } + } + } + int64_t node_index = nodes_.size(); + nodes_.push_back(node); + node_index_map_.Set(value, node_index); + return node_index; + } + + json::Array CreateArrayData(const Array& value) { + json::Array data; + data.reserve(value.size()); + for (const Any& item : value) { + data.push_back(GetOrCreateNodeIndex(item)); + } + return data; + } + + json::Array CreateMapData(const Map& value) { + json::Array data; + data.reserve(value.size() * 2); + for (const auto& [key, value] : value) { + data.push_back(GetOrCreateNodeIndex(key)); + data.push_back(GetOrCreateNodeIndex(value)); + } + return data; + } + + // create the data for the object, if the type has a custom data to json function, + // use it. otherwise, we go over the fields and create the data. + json::Object CreateObjectData(const Any& value) { + static reflection::TypeAttrColumn data_to_json = reflection::TypeAttrColumn("__data_to_json__"); + if (data_to_json[value.type_index()] != nullptr) { + return data_to_json[value.type_index()].cast()(value).cast(); + } + // NOTE: invariant: lhs and rhs are already the same type + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(value.type_index()); + if (type_info->metadata == nullptr) { + TVM_FFI_THROW(TypeError) << "Type metadata is not set for type `" + << String(type_info->type_key) + << "`, so ToJSONGraph is not supported for this type"; + } + const Object* obj = value.cast(); + json::Object data; + // go over the content and hash the fields + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + // get the field value from both side + reflection::FieldGetter getter(field_info); + Any field_value = getter(obj); + int field_static_type_index = field_info->field_static_type_index; + String field_name(field_info->name); + // for static field index that are known, we can directly set the field value. + switch (field_static_type_index) { + case TypeIndex::kTVMFFINone: { + data.Set(field_name, nullptr); + break; + } + case TypeIndex::kTVMFFIBool: { + data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); + break; + } + case TypeIndex::kTVMFFIInt: { + data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); + break; + } + case TypeIndex::kTVMFFIFloat: { + data.Set(field_name, details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value)); + break; + } + case TypeIndex::kTVMFFIDataType: { + DLDataType dtype = details::AnyUnsafe::CopyFromAnyViewAfterCheck(field_value); + data.Set(field_name, DLDataTypeToString(dtype)); + break; + } + default: { + // for dynamic field index, we need need to put them onto nodes + int64_t node_index = GetOrCreateNodeIndex(field_value); + data.Set(field_name, node_index); + break; + } + } + }); + return data; + } + + // maps the original value to the index of the node in the nodes_ array + Map node_index_map_; + // records nodes that are serialized + json::Array nodes_; +}; + +json::Value ToJSONGraph(const Any& value, const Any& metadata) { + return ObjectGraphSerializer::Serialize(value, metadata); +} + +class ObjectGraphDeserializer { + public: + static Any Deserialize(const json::Value& value) { + ObjectGraphDeserializer deserializer(value); + return deserializer.GetOrDecodeNode(deserializer.root_index_); + } + + Any GetOrDecodeNode(int64_t node_index) { + // already decoded null index + if (node_index == decoded_null_index_) { + return Any(nullptr); + } + // already decoded + if (decoded_nodes_[node_index] != nullptr) { + return decoded_nodes_[node_index]; + } + // now decode the node + Any value = DecodeNode(nodes_[node_index].cast()); + decoded_nodes_[node_index] = value; + if (value == nullptr) { + decoded_null_index_ = node_index; + } + return value; + } + + private: + Any DecodeNode(const json::Object& node) { + String type_key = node["type"].cast(); + TVMFFIByteArray type_key_arr{type_key.data(), type_key.length()}; + int32_t type_index; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); + + switch (type_index) { + case TypeIndex::kTVMFFINone: { + return nullptr; + } + case TypeIndex::kTVMFFIBool: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIInt: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIFloat: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIDataType: { + return StringToDLDataType(node["data"].cast()); + } + case TypeIndex::kTVMFFIDevice: { + Array data = node["data"].cast>(); + return DLDevice{static_cast(data[0]), data[1]}; + } + case TypeIndex::kTVMFFIStr: { + return node["data"].cast(); + } + case TypeIndex::kTVMFFIBytes: { + return Base64Decode(node["data"].cast()); + } + case TypeIndex::kTVMFFIMap: { + return DecodeMapData(node["data"].cast()); + } + case TypeIndex::kTVMFFIArray: { + return DecodeArrayData(node["data"].cast()); + } + default: { + return DecodeObjectData(type_index, node["data"]); + } + } + } + + Array DecodeArrayData(const json::Array& data) { + Array array; + array.reserve(data.size()); + for (size_t i = 0; i < data.size(); i++) { + array.push_back(GetOrDecodeNode(data[i].cast())); + } + return array; + } + + Map DecodeMapData(const json::Array& data) { + Map map; + for (size_t i = 0; i < data.size(); i += 2) { + int64_t key_index = data[i].cast(); + int64_t value_index = data[i + 1].cast(); + map.Set(GetOrDecodeNode(key_index), GetOrDecodeNode(value_index)); + } + return map; + } + + Any DecodeObjectData(int32_t type_index, const json::Value& data) { + static reflection::TypeAttrColumn data_from_json = + reflection::TypeAttrColumn("__data_from_json__"); + if (data_from_json[type_index] != nullptr) { + return data_from_json[type_index].cast()(data); + } + // otherwise, we go over the fields and create the data. + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not support default constructor" + << ", so ToJSONGraph is not supported for this type"; + } + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); + ObjectPtr ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + + auto decode_field_value = [&](const TVMFFIFieldInfo* field_info, json::Value data) -> Any { + switch (field_info->field_static_type_index) { + case TypeIndex::kTVMFFINone: { + return nullptr; + } + case TypeIndex::kTVMFFIBool: { + return data.cast(); + } + case TypeIndex::kTVMFFIInt: { + return data.cast(); + } + case TypeIndex::kTVMFFIFloat: { + return data.cast(); + } + case TypeIndex::kTVMFFIDataType: { + return StringToDLDataType(data.cast()); + } + default: { + return GetOrDecodeNode(data.cast()); + } + } + }; + + json::Object data_object = data.cast(); + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + String field_name(field_info->name); + void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; + if (data_object.count(field_name) != 0) { + Any field_value = decode_field_value(field_info, data_object[field_name]); + field_info->setter(field_addr, reinterpret_cast(&field_value)); + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + field_info->setter(field_addr, &(field_info->default_value)); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" + << String(field_info->name.data, field_info->name.size) + << "` not set in type `" << TypeIndexToTypeKey(type_index) << "`"; + } + }); + return ObjectRef(ptr); + } + + explicit ObjectGraphDeserializer(json::Value serialized) { + if (!serialized.as()) { + TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected an object"; + } + json::Object encoded_object = serialized.cast(); + if (encoded_object.count("root_index") == 0 || !encoded_object["root_index"].as()) { + TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `root_index` integer field"; + } + if (encoded_object.count("nodes") == 0 || !encoded_object["nodes"].as()) { + TVM_FFI_THROW(ValueError) << "Invalid JSON Object Graph, expected `nodes` array field"; + } + root_index_ = encoded_object["root_index"].cast(); + nodes_ = encoded_object["nodes"].cast(); + decoded_nodes_.resize(nodes_.size(), Any(nullptr)); + } + // nodes + json::Array nodes_; + // root index + int64_t root_index_; + // null index if already created + int64_t decoded_null_index_{-1}; + // decoded nodes + std::vector decoded_nodes_; +}; + +Any FromJSONGraph(const json::Value& value) { return ObjectGraphDeserializer::Deserialize(value); } + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ffi.ToJSONGraph", ToJSONGraph).def("ffi.FromJSONGraph", FromJSONGraph); + refl::EnsureTypeAttrColumn("__data_to_json__"); + refl::EnsureTypeAttrColumn("__data_from_json__"); +}); + +} // namespace ffi +} // namespace tvm diff --git a/ffi/tests/cpp/extra/test_serialization.cc b/ffi/tests/cpp/extra/test_serialization.cc new file mode 100644 index 000000000000..f0aefa370966 --- /dev/null +++ b/ffi/tests/cpp/extra/test_serialization.cc @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Serialization, BoolNull) { + json::Object expected_null = + json::Object{{"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "None"}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(nullptr), expected_null)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_null), nullptr)); + + json::Object expected_true = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(true), expected_true)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_true), true)); + + json::Object expected_false = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", false}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(false), expected_false)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_false), false)); +} + +TEST(Serialization, IntegerTypes) { + // Test positive integer + json::Object expected_int = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "int"}, {"data", 42}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(static_cast(42)), expected_int)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int), static_cast(42))); +} + +TEST(Serialization, FloatTypes) { + // Test positive float + json::Object expected_float = + json::Object{{"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "float"}, {"data", 3.14159}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(3.14159), expected_float)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float), 3.14159)); +} + +TEST(Serialization, StringTypes) { + // Test short string + json::Object expected_short = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("hello")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello")), expected_short)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_short), String("hello"))); + + // Test long string + std::string long_str(1000, 'x'); + json::Object expected_long = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String(long_str)}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String(long_str)), expected_long)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_long), String(long_str))); + + // Test string with special characters + json::Object expected_special = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, + {"data", String("hello\nworld\t\"quotes\"")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(String("hello\nworld\t\"quotes\"")), expected_special)); + EXPECT_TRUE( + StructuralEqual()(FromJSONGraph(expected_special), String("hello\nworld\t\"quotes\""))); +} + +TEST(Serialization, Bytes) { + // Test empty bytes + Bytes empty_bytes; + json::Object expected_empty = json::Object{ + {"root_index", 0}, {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", ""}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_bytes), expected_empty)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_bytes)); + + // Test bytes with that encoded as base64 + Bytes bytes_content = Bytes("abcd"); + json::Object expected_encoded = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", "YWJjZA=="}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_content), expected_encoded)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded), bytes_content)); + + // Test bytes with that encoded as base64, that contains control characters via utf-8 + char bytes_v2_content[] = {0x01, 0x02, 0x03, 0x04, 0x01, 0x0b}; + Bytes bytes_v2 = Bytes(bytes_v2_content, sizeof(bytes_v2_content)); + json::Object expected_encoded_v2 = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Bytes"}, {"data", "AQIDBAEL"}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(bytes_v2), expected_encoded_v2)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_encoded_v2), bytes_v2)); +} + +TEST(Serialization, DataTypes) { + // Test int32 dtype + DLDataType int32_dtype; + int32_dtype.code = kDLInt; + int32_dtype.bits = 32; + int32_dtype.lanes = 1; + + json::Object expected_int32 = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("int32")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(int32_dtype), expected_int32)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_int32), int32_dtype)); + + // Test float64 dtype + DLDataType float64_dtype; + float64_dtype.code = kDLFloat; + float64_dtype.bits = 64; + float64_dtype.lanes = 1; + + json::Object expected_float64 = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float64")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(float64_dtype), expected_float64)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_float64), float64_dtype)); + + // Test vector dtype + DLDataType vector_dtype; + vector_dtype.code = kDLFloat; + vector_dtype.bits = 32; + vector_dtype.lanes = 4; + + json::Object expected_vector = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "DataType"}, {"data", String("float32x4")}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(vector_dtype), expected_vector)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_vector), vector_dtype)); +} + +TEST(Serialization, DeviceTypes) { + // Test CPU device + DLDevice cpu_device; + cpu_device.device_type = kDLCPU; + cpu_device.device_id = 0; + + json::Object expected_cpu = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "Device"}, + {"data", json::Array{static_cast(kDLCPU), + static_cast(0)}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(cpu_device), expected_cpu)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_cpu), cpu_device)); + + // Test GPU device + DLDevice gpu_device; + gpu_device.device_type = kDLCUDA; + gpu_device.device_id = 1; + + json::Object expected_gpu = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{ + {"type", "Device"}, {"data", json::Array{static_cast(kDLCUDA), 1}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(gpu_device), expected_gpu)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_gpu), gpu_device)); +} + +TEST(Serialization, Arrays) { + // Test empty array + Array empty_array; + json::Object expected_empty = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_array), expected_empty)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_array)); + + // Test single element array + Array single_array; + single_array.push_back(Any(42)); + json::Object expected_single = + json::Object{{"root_index", 1}, + {"nodes", json::Array{ + json::Object{{"type", "int"}, {"data", static_cast(42)}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{0}}}, + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_array), expected_single)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_array)); + + // Test duplicated element array + Array duplicated_array; + duplicated_array.push_back(42); + duplicated_array.push_back(42); + json::Object expected_duplicated = + json::Object{{"root_index", 1}, + {"nodes", json::Array{ + json::Object{{"type", "int"}, {"data", 42}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 0}}}, + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_array), expected_duplicated)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_array)); + // Test mixed element array, note that 42 and "hello" are duplicated and will + // be indexed as 0 and 1 + Array mixed_array; + mixed_array.push_back(42); + mixed_array.push_back(String("hello")); + mixed_array.push_back(true); + mixed_array.push_back(nullptr); + mixed_array.push_back(42); + mixed_array.push_back(String("hello")); + json::Object expected_mixed = json::Object{ + {"root_index", 4}, + {"nodes", json::Array{ + json::Object{{"type", "int"}, {"data", 42}}, + json::Object{{"type", "ffi.String"}, {"data", String("hello")}}, + json::Object{{"type", "bool"}, {"data", true}}, + json::Object{{"type", "None"}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{0, 1, 2, 3, 0, 1}}}, + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(mixed_array), expected_mixed)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_mixed), mixed_array)); +} + +TEST(Serialization, Maps) { + // Test empty map + Map empty_map; + json::Object expected_empty = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "ffi.Map"}, {"data", json::Array{}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_map), expected_empty)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_map)); + + // Test single element map + Map single_map{{"key", 42}}; + json::Object expected_single = json::Object{ + {"root_index", 2}, + {"nodes", json::Array{json::Object{{"type", "ffi.String"}, {"data", String("key")}}, + json::Object{{"type", "int"}, {"data", 42}}, + json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_map), expected_single)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_map)); + + // Test duplicated element map + Map duplicated_map{{"b", 42}, {"a", 42}}; + json::Object expected_duplicated = json::Object{ + {"root_index", 3}, + {"nodes", json::Array{ + json::Object{{"type", "ffi.String"}, {"data", "b"}}, + json::Object{{"type", "int"}, {"data", 42}}, + json::Object{{"type", "ffi.String"}, {"data", "a"}}, + json::Object{{"type", "ffi.Map"}, {"data", json::Array{0, 1, 2, 1}}}, + + }}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(duplicated_map), expected_duplicated)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_duplicated), duplicated_map)); +} + +TEST(Serialization, TestObjectVar) { + TVar x = TVar("x"); + json::Object expected_x = json::Object{ + {"root_index", 1}, + {"nodes", + json::Array{json::Object{{"type", "ffi.String"}, {"data", "x"}}, + json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(x), expected_x)); + EXPECT_TRUE(StructuralEqual::Equal(FromJSONGraph(expected_x), x, /*map_free_vars=*/true)); +} + +TEST(Serialization, TestObjectIntCustomToJSON) { + TInt value = TInt(42); + json::Object expected_i = json::Object{ + {"root_index", 0}, + {"nodes", + json::Array{json::Object{{"type", "test.Int"}, {"data", json::Object{{"value", 42}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value), expected_i)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_i), value)); +} + +TEST(Serialization, TestObjectFunc) { + TVar x = TVar("x"); + // comment fields are ignored + TFunc fa = TFunc({x}, {x, x}, String("comment a")); + + json::Object expected_fa = json::Object{ + {"root_index", 5}, + {"nodes", + json::Array{ + json::Object{{"type", "ffi.String"}, {"data", "x"}}, // string "x" + json::Object{{"type", "test.Var"}, {"data", json::Object{{"name", 0}}}}, // var x + json::Object{{"type", "ffi.Array"}, {"data", json::Array{1}}}, // array [x] + json::Object{{"type", "ffi.Array"}, {"data", json::Array{1, 1}}}, // array [x, x] + json::Object{{"type", "ffi.String"}, {"data", "comment a"}}, // "comment a" + json::Object{{"type", "test.Func"}, + {"data", json::Object{{"params", 2}, {"body", 3}, {"comment", 4}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fa), expected_fa)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fa), fa)); + + TFunc fb = TFunc({}, {}, std::nullopt); + json::Object expected_fb = json::Object{ + {"root_index", 3}, + {"nodes", + json::Array{ + json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, + json::Object{{"type", "ffi.Array"}, {"data", json::Array{}}}, + json::Object{{"type", "None"}}, + json::Object{{"type", "test.Func"}, + {"data", json::Object{{"params", 0}, {"body", 1}, {"comment", 2}}}}}}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(fb), expected_fb)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_fb), fb)); +} + +TEST(Serialization, AttachMetadata) { + bool value = true; + json::Object metadata{{"version", "1.0"}}; + json::Object expected = + json::Object{{"root_index", 0}, + {"nodes", json::Array{json::Object{{"type", "bool"}, {"data", true}}}}, + {"metadata", metadata}}; + EXPECT_TRUE(StructuralEqual()(ToJSONGraph(value, metadata), expected)); + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), value)); +} + +TEST(Serialization, ShuffleNodeOrder) { + // the FromJSONGraph is agnostic to the node order + // so we can shuffle the node order as it reads nodes lazily + Map duplicated_map{{"b", 42}, {"a", 42}}; + json::Object expected_shuffled = json::Object{ + {"root_index", 0}, + {"nodes", json::Array{ + json::Object{{"type", "ffi.Map"}, {"data", json::Array{2, 3, 1, 3}}}, + json::Object{{"type", "ffi.String"}, {"data", "a"}}, + json::Object{{"type", "ffi.String"}, {"data", "b"}}, + json::Object{{"type", "int"}, {"data", 42}}, + }}}; + EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_shuffled), duplicated_map)); +} + +} // namespace diff --git a/ffi/tests/cpp/extra/test_structural_equal_hash.cc b/ffi/tests/cpp/extra/test_structural_equal_hash.cc index 76c485d9062e..8a377f483713 100644 --- a/ffi/tests/cpp/extra/test_structural_equal_hash.cc +++ b/ffi/tests/cpp/extra/test_structural_equal_hash.cc @@ -147,10 +147,10 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) { TVar x = TVar("x"); TVar y = TVar("y"); // comment fields are ignored - TFunc fa = TFunc({x}, {TInt(1), x}, "comment a"); - TFunc fb = TFunc({y}, {TInt(1), y}, "comment b"); + TFunc fa = TFunc({x}, {TInt(1), x}, String("comment a")); + TFunc fb = TFunc({y}, {TInt(1), y}, String("comment b")); - TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, "comment c"); + TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, String("comment c")); EXPECT_TRUE(StructuralEqual()(fa, fb)); EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb)); diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index 78ca008e1094..c5725da941f5 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -21,6 +21,7 @@ #define TVM_FFI_TESTING_OBJECT_H_ #include +#include #include #include #include @@ -87,6 +88,15 @@ inline void TIntObj::RegisterReflection() { refl::TypeAttrDef() .def("test.GetValue", &TIntObj::GetValue) .attr("test.size", sizeof(TIntObj)); + // custom json serialization + refl::TypeAttrDef() + .def("__data_to_json__", + [](const TIntObj* self) -> Map { + return Map{{"value", self->value}}; + }) + .def("__data_from_json__", [](Map json_obj) -> TInt { + return TInt(json_obj["value"].cast()); + }); } class TFloatObj : public TNumberObj { @@ -154,6 +164,8 @@ class TVarObj : public Object { public: std::string name; + // need default constructor for json serialization + TVarObj() = default; TVarObj(std::string name) : name(name) {} static void RegisterReflection() { @@ -178,9 +190,11 @@ class TFuncObj : public Object { public: Array params; Array body; - String comment; + Optional comment; - TFuncObj(Array params, Array body, String comment) + // need default constructor for json serialization + TFuncObj() = default; + TFuncObj(Array params, Array body, Optional comment) : params(params), body(body), comment(comment) {} static void RegisterReflection() { @@ -198,7 +212,7 @@ class TFuncObj : public Object { class TFunc : public ObjectRef { public: - explicit TFunc(Array params, Array body, String comment) { + explicit TFunc(Array params, Array body, Optional comment) { data_ = make_object(params, body, comment); }