forked from mlc-ai/mlc-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrating MLC runtime with the new compilation workflow (mlc-ai#1203)
- Loading branch information
Showing
10 changed files
with
282 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#ifndef MLC_LLM_CPP_JSON_PARSER_H_ | ||
#define MLC_LLM_CPP_JSON_PARSER_H_ | ||
|
||
#define PICOJSON_USE_INT64 | ||
#ifndef __STDC_FORMAT_MACROS | ||
#define __STDC_FORMAT_MACROS | ||
#endif | ||
|
||
#include <picojson.h> | ||
#include <tvm/runtime/container/shape_tuple.h> | ||
#include <tvm/runtime/data_type.h> | ||
#include <tvm/runtime/logging.h> | ||
|
||
namespace mlc { | ||
namespace llm { | ||
namespace json { | ||
|
||
template <typename ValueType> | ||
inline ValueType Lookup(const picojson::object& json, const std::string& key) { | ||
auto it = json.find(key); | ||
CHECK(it != json.end()) << "ValueError: key `" << key << "` not found in the JSON object"; | ||
CHECK(it->second.is<ValueType>()) << "ValueError: key `" << key << "` has unexpected type"; | ||
return it->second.get<ValueType>(); | ||
} | ||
|
||
template <> | ||
inline tvm::runtime::DataType Lookup(const picojson::object& json, const std::string& key) { | ||
return tvm::runtime::DataType(tvm::runtime::String2DLDataType(Lookup<std::string>(json, key))); | ||
} | ||
|
||
template <> | ||
inline tvm::runtime::ShapeTuple Lookup(const picojson::object& json, const std::string& key) { | ||
picojson::array shape = Lookup<picojson::array>(json, key); | ||
std::vector<int64_t> result; | ||
result.reserve(shape.size()); | ||
for (const picojson::value& dim : shape) { | ||
CHECK(dim.is<int64_t>()) << "ValueError: key `" << key << "` has unexpected type"; | ||
result.push_back(dim.get<int64_t>()); | ||
} | ||
return tvm::runtime::ShapeTuple(std::move(result)); | ||
} | ||
|
||
/*!
 * \brief Parse a string as JSON and return it as a JSON object.
 * \param json_str The JSON text to parse.
 * \return A copy of the parsed top-level object.
 * \note Reports a fatal error if picojson fails to parse the string, and fails a CHECK
 *       if the parsed top-level value is not an object.
 */
inline picojson::object ParseObject(const std::string& json_str) {
  picojson::value result;
  // picojson::parse returns an empty string on success, otherwise the error description.
  std::string err = picojson::parse(result, json_str);
  if (!err.empty()) {
    // Stream the actual parser error; it was previously embedded in the literal as the text "err".
    LOG(FATAL) << "Failed to parse JSON: " << err << ". The JSON string is:" << json_str;
  }
  CHECK(result.is<picojson::object>())
      << "ValueError: The given string is not a JSON object: " << json_str;
  return result.get<picojson::object>();
}
|
||
/*!
 * \brief View a generic JSON value as a JSON object.
 * \param json The JSON value to convert.
 * \return A copy of the object held by `json`.
 * \note A CHECK fails if `json` does not hold an object.
 */
inline picojson::object AsJSONObject(const picojson::value& json) {
  CHECK(json.is<picojson::object>()) << "ValueError: The given value is not a JSON object";
  const picojson::object& held_object = json.get<picojson::object>();
  return held_object;
}
|
||
} // namespace json | ||
} // namespace llm | ||
} // namespace mlc | ||
|
||
#endif // MLC_LLM_CPP_JSON_PARSER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#include "./model_metadata.h" | ||
|
||
#include <tvm/runtime/packed_func.h> | ||
|
||
#include "./json_parser.h" | ||
|
||
namespace mlc { | ||
namespace llm { | ||
|
||
using namespace tvm::runtime; | ||
|
||
/*!
 * \brief Build a Param from a JSON object with the keys "name", "shape" and "dtype".
 * \param param The JSON object describing one parameter.
 * \return The populated Param.
 */
ModelMetadata::Param ModelMetadata::Param::FromJSON(const picojson::object& param) {
  Param parsed;
  // Each lookup validates both the presence and the type of its field.
  parsed.name = json::Lookup<std::string>(param, "name");
  parsed.shape = json::Lookup<ShapeTuple>(param, "shape");
  parsed.dtype = json::Lookup<DataType>(param, "dtype");
  return parsed;
}
|
||
/*!
 * \brief Build a ModelMetadata from a JSON object with the keys "model_type",
 *        "quantization" and "params".
 * \param metadata The JSON object holding the metadata.
 * \return The populated ModelMetadata; `params` holds one entry per element of the
 *         "params" array, in the same order.
 */
ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata) {
  ModelMetadata parsed;
  parsed.model_type = json::Lookup<std::string>(metadata, "model_type");
  parsed.quantization = json::Lookup<std::string>(metadata, "quantization");

  const picojson::array param_list = json::Lookup<picojson::array>(metadata, "params");
  parsed.params.reserve(param_list.size());
  for (size_t i = 0; i < param_list.size(); ++i) {
    // Every element of "params" must itself be a JSON object.
    parsed.params.push_back(Param::FromJSON(json::AsJSONObject(param_list[i])));
  }
  return parsed;
}
|
||
/*!
 * \brief Read the metadata embedded in a compiled module.
 *
 * The module's `_metadata` function is called to obtain a JSON string, which is then
 * parsed into a ModelMetadata.
 *
 * \param module The runtime module to query.
 * \return The parsed metadata, or a default-constructed ModelMetadata if the `_metadata`
 *         function cannot be obtained or called.
 * \note Failures while parsing the JSON string are not swallowed: the offending string is
 *       logged and the original exception is rethrown.
 */
ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module) {
  std::string json_str = "";
  try {
    TypedPackedFunc<String()> pf = module.GetFunction("_metadata");
    ICHECK(pf != nullptr);
    json_str = pf();
  } catch (...) {
    // Deliberate best-effort: any failure to fetch the metadata yields empty metadata.
    return ModelMetadata();  // TODO: add a warning message about legacy usecases
  }
  picojson::object json = json::ParseObject(json_str);
  try {
    return ModelMetadata::FromJSON(json);
  } catch (const std::exception&) {
    LOG(WARNING) << "Failed to parse metadata:\n" << json_str;
    // Rethrow the in-flight exception. `throw e;` would copy it as a plain std::exception,
    // slicing away its dynamic type and message.
    throw;
  }
}
|
||
} // namespace llm | ||
} // namespace mlc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/*! | ||
* \file model_metadata.h | ||
* \brief Metadata stored in model lib | ||
*/ | ||
#include <tvm/runtime/container/shape_tuple.h> | ||
#include <tvm/runtime/container/string.h> | ||
#include <tvm/runtime/data_type.h> | ||
#include <tvm/runtime/module.h> | ||
|
||
#include <unordered_map> | ||
|
||
// Forward declarations of the picojson types used in the signatures below, so this header
// does not need to include <picojson.h>.
// NOTE(review): the `object` alias presumably has to stay identical to the definition in the
// picojson header used by the project -- confirm when upgrading picojson.
namespace picojson {
class value;
using object = std::unordered_map<std::string, value>;
}  // namespace picojson
|
||
namespace mlc { | ||
namespace llm { | ||
|
||
struct ModelMetadata { | ||
struct Param { | ||
tvm::runtime::String name; | ||
tvm::runtime::ShapeTuple shape; | ||
tvm::runtime::DataType dtype; | ||
|
||
static Param FromJSON(const picojson::object& param_obj); | ||
}; | ||
std::string model_type; | ||
std::string quantization; | ||
std::vector<Param> params; | ||
|
||
static ModelMetadata FromJSON(const picojson::object& json_str); | ||
static ModelMetadata FromModule(tvm::runtime::Module module); | ||
}; | ||
|
||
} // namespace llm | ||
} // namespace mlc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.