Skip to content

Commit

Permalink
Integrating MLC runtime with the new compilation workflow (mlc-ai#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Nov 6, 2023
1 parent 3413d17 commit 7ccb51a
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 59 deletions.
63 changes: 63 additions & 0 deletions cpp/json_parser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#ifndef MLC_LLM_CPP_JSON_PARSER_H_
#define MLC_LLM_CPP_JSON_PARSER_H_

#define PICOJSON_USE_INT64
#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS
#endif

#include <picojson.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>

namespace mlc {
namespace llm {
namespace json {

template <typename ValueType>
inline ValueType Lookup(const picojson::object& json, const std::string& key) {
auto it = json.find(key);
CHECK(it != json.end()) << "ValueError: key `" << key << "` not found in the JSON object";
CHECK(it->second.is<ValueType>()) << "ValueError: key `" << key << "` has unexpected type";
return it->second.get<ValueType>();
}

template <>
inline tvm::runtime::DataType Lookup(const picojson::object& json, const std::string& key) {
return tvm::runtime::DataType(tvm::runtime::String2DLDataType(Lookup<std::string>(json, key)));
}

template <>
inline tvm::runtime::ShapeTuple Lookup(const picojson::object& json, const std::string& key) {
picojson::array shape = Lookup<picojson::array>(json, key);
std::vector<int64_t> result;
result.reserve(shape.size());
for (const picojson::value& dim : shape) {
CHECK(dim.is<int64_t>()) << "ValueError: key `" << key << "` has unexpected type";
result.push_back(dim.get<int64_t>());
}
return tvm::runtime::ShapeTuple(std::move(result));
}

inline picojson::object ParseObject(const std::string& json_str) {
picojson::value result;
std::string err = picojson::parse(result, json_str);
if (!err.empty()) {
LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << json_str;
}
CHECK(result.is<picojson::object>())
<< "ValueError: The given string is not a JSON object: " << json_str;
return result.get<picojson::object>();
}

inline picojson::object AsJSONObject(const picojson::value& json) {
CHECK(json.is<picojson::object>()) << "ValueError: The given value is not a JSON object";
return json.get<picojson::object>();
}

} // namespace json
} // namespace llm
} // namespace mlc

#endif // MLC_LLM_CPP_JSON_PARSER_H_
34 changes: 29 additions & 5 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <vector>

#include "conversation.h"
#include "model_metadata.h"
#include "random.h"
#include "support.h"
#include "tokenizers.h"
Expand Down Expand Up @@ -161,13 +162,18 @@ struct FunctionTable {
static_cast<int>(relax_vm::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
static_cast<int>(relax_vm::AllocatorType::kPooled));
this->mod_get_func = [this](const std::string& name) -> PackedFunc {
return this->local_vm->GetFunction(name, false);
PackedFunc func = this->local_vm->GetFunction(name, false);
if (func == nullptr) {
LOG(WARNING) << "Cannot find function in VM: " << name;
}
return func;
};
this->get_global_func = [](const std::string& name) -> PackedFunc {
const auto* f = tvm::runtime::Registry::Get(name);
CHECK(f != nullptr) << "ValueError: Cannot find function " << name;
return *f;
};
this->model_metadata_ = ModelMetadata::FromModule(this->local_vm);
this->_InitFunctions();
}
}
Expand All @@ -188,10 +194,23 @@ struct FunctionTable {
const PackedFunc* fload_cache = tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.load");
ICHECK(fload_cache) << "TVM runtime cannot find vm.builtin.ndarray_cache.load";
(*fload_cache)(model_path, static_cast<int32_t>(device.device_type), device.device_id);
const PackedFunc* fload_params =
tvm::runtime::Registry::Get("vm.builtin.param_array_from_cache");
ICHECK(fload_params) << "Cannot find env function vm.builtin.param_array_from_cache";
Array<NDArray> params = (*fload_params)("param", -1);
Array<NDArray> params;
if (this->model_metadata_.params.empty()) {
constexpr const char* name_loader = "vm.builtin.param_array_from_cache";
const PackedFunc* fload_params = tvm::runtime::Registry::Get(name_loader);
ICHECK(fload_params) << "Cannot find env function: " << name_loader;
params = (*fload_params)("param", -1);
} else {
constexpr const char* name_loader = "vm.builtin.param_array_from_cache_by_name";
const PackedFunc* fload_params = tvm::runtime::Registry::Get(name_loader);
ICHECK(fload_params) << "Cannot find env function: " << name_loader;
Array<String> param_names;
param_names.reserve(this->model_metadata_.params.size());
for (const auto& param : this->model_metadata_.params) {
param_names.push_back(param.name);
}
params = (*fload_params)(param_names);
}
// after we get params, it is safe to simply clear the cached version
// as these params are referenced by params_
const PackedFunc* fclear_ndarray_cache =
Expand All @@ -210,6 +229,9 @@ struct FunctionTable {
this->softmax_func_ = mod_get_func("softmax_with_temperature");
this->encoding_without_cache_func_ = mod_get_func("encoding_without_cache");
this->create_kv_cache_func_ = mod_get_func("create_kv_cache");
if (this->create_kv_cache_func_ == nullptr) {
this->create_kv_cache_func_ = mod_get_func("_initialize_effect");
}
this->reset_kv_cache_func_ = mod_get_func("reset_kv_cache");
if (this->reset_kv_cache_func_ == nullptr) {
this->reset_kv_cache_func_ = get_global_func("vm.builtin.attention_kv_cache_array_clear");
Expand Down Expand Up @@ -260,6 +282,7 @@ struct FunctionTable {
PackedFunc reset_kv_cache_func_;
bool support_backtracking_kv_;
PackedFunc fkvcache_array_popn_;
ModelMetadata model_metadata_;
};

} // namespace
Expand Down Expand Up @@ -437,6 +460,7 @@ class LLMChat {
* \note This function overrides existing configurations.
*/
void LoadJSONOverride(const std::string& config_str, bool partial_update = false) {
LOG(INFO) << "config_str = " << config_str;
picojson::value config_json;
std::string err = picojson::parse(config_json, config_str);
if (!err.empty()) {
Expand Down
51 changes: 51 additions & 0 deletions cpp/model_metadata.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "./model_metadata.h"

#include <tvm/runtime/packed_func.h>

#include "./json_parser.h"

namespace mlc {
namespace llm {

using namespace tvm::runtime;

ModelMetadata::Param ModelMetadata::Param::FromJSON(const picojson::object& param) {
Param result;
result.name = json::Lookup<std::string>(param, "name");
result.shape = json::Lookup<ShapeTuple>(param, "shape");
result.dtype = json::Lookup<DataType>(param, "dtype");
return result;
}

ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata) {
ModelMetadata result;
result.model_type = json::Lookup<std::string>(metadata, "model_type");
result.quantization = json::Lookup<std::string>(metadata, "quantization");
picojson::array params = json::Lookup<picojson::array>(metadata, "params");
result.params.reserve(params.size());
for (const picojson::value& json_param : params) {
result.params.emplace_back(ModelMetadata::Param::FromJSON(json::AsJSONObject(json_param)));
}
return result;
}

ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module) {
std::string json_str = "";
try {
TypedPackedFunc<String()> pf = module.GetFunction("_metadata");
ICHECK(pf != nullptr);
json_str = pf();
} catch (...) {
return ModelMetadata(); // TODO: add a warning message about legacy usecases
}
picojson::object json = json::ParseObject(json_str);
try {
return ModelMetadata::FromJSON(json);
} catch (const std::exception& e) {
LOG(WARNING) << "Failed to parse metadata:\n" << json_str;
throw e;
}
}

} // namespace llm
} // namespace mlc
37 changes: 37 additions & 0 deletions cpp/model_metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*!
* \file model_metadata.h
* \brief Metadata stored in model lib
*/
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/module.h>

#include <unordered_map>

namespace picojson {
class value;
using object = std::unordered_map<std::string, value>;
} // namespace picojson

namespace mlc {
namespace llm {

struct ModelMetadata {
struct Param {
tvm::runtime::String name;
tvm::runtime::ShapeTuple shape;
tvm::runtime::DataType dtype;

static Param FromJSON(const picojson::object& param_obj);
};
std::string model_type;
std::string quantization;
std::vector<Param> params;

static ModelMetadata FromJSON(const picojson::object& json_str);
static ModelMetadata FromModule(tvm::runtime::Module module);
};

} // namespace llm
} // namespace mlc
48 changes: 45 additions & 3 deletions python/mlc_chat/compiler/compile.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Python entrypoint of compilation."""
import dataclasses
import json
import logging
from io import StringIO
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, List, Optional, Tuple

from tvm import IRModule, relax
from tvm.relax.frontend import nn
from tvm.target import Target

from ..support.style import bold
Expand Down Expand Up @@ -46,21 +48,61 @@ def display(self) -> None:
print(out.getvalue().rstrip())


def _attach_auxiliary_methods(
mod: IRModule,
named_params: List[Tuple[str, nn.Parameter]],
args: CompileArgs,
model_config,
) -> None:
def _metadata():
metadata = {
"quantization": args.quantization.name,
"model_type": args.model.name,
"params": [
{
"name": name,
"shape": list(param.shape),
"dtype": param.dtype,
}
for name, param in named_params
],
}
bb = relax.BlockBuilder() # pylint: disable=invalid-name
with bb.function("main", params=[]):
bb.emit_func_output(relax.StringImm(json.dumps(metadata)))
return bb.get()["main"]

def _attach_variable_bounds():
for g_var, func in mod.functions_items():
if isinstance(func, relax.Function):
mod[g_var] = func.with_attr(
"tir_var_upper_bound",
{
"seq_len": model_config.max_sequence_length,
"total_seq_len": model_config.max_sequence_length,
},
)

mod["_metadata"] = _metadata()
_attach_variable_bounds()


def _compile(args: CompileArgs):
logger.info("Creating model from: %s", args.config)
model_config = args.model.config.from_file(args.config)
args.overrides.apply(model_config)
model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization)
logger.info("Exporting the model to TVM Unity compiler")
mod, _named_params = model.export_tvm(
mod, named_params = model.export_tvm(
spec=model.get_default_spec(), # type: ignore
)
_attach_auxiliary_methods(mod, named_params, args, model_config)
logger.info("Running optimizations using TVM Unity")
with args.target:
mod = relax.get_pipeline("mlc_llm")(mod)
logger.info("Generating code using TVM Unity")
args.build_func(mod, args)
logger.info("Code dumped to: %s", bold(str(args.output)))
logger.info("Generated: %s", bold(str(args.output)))


def compile( # pylint: disable=too-many-arguments,redefined-builtin
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""A compiler pass that fuses decode + matmul + elementwise."""
"""A compiler pass that fuses dequantize + matmul + elementwise."""
import tvm
from tvm import IRModule, relax
from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard


@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise")
class FuseDecodeMatmulEwise: # pylint: disable=too-few-public-methods
"""A compiler pass that fuses decode + matmul + elementwise."""
@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeMatmulEwise")
class FuseDequantizeMatmulEwise: # pylint: disable=too-few-public-methods
"""A compiler pass that fuses dequantize + matmul + elementwise."""

def transform_module(
self,
Expand All @@ -23,7 +23,7 @@ def transform_module(
relax.transform.FuseOpsByPattern(
[
(
"decode_matmul",
"dequantize_matmul",
*_pattern(match_ewise, n_aux_tensor),
)
]
Expand Down Expand Up @@ -62,7 +62,9 @@ def _check_decoding(ctx: relax.transform.PatternCheckContext) -> bool:
g_var = call.args[0]
if not isinstance(g_var, relax.GlobalVar):
return False
return g_var.name_hint.startswith("decode") or g_var.name_hint.startswith("fused_decode")
return g_var.name_hint.startswith("dequantize") or g_var.name_hint.startswith(
"fused_dequantize"
)

def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool:
call = ctx.annotated_expr["matmul"]
Expand Down
Loading

0 comments on commit 7ccb51a

Please sign in to comment.