Support execution using pre-sharded weights
Lunderberg committed Nov 6, 2023
1 parent 7b7d992 commit b85f723
Showing 2 changed files with 20 additions and 3 deletions.
20 changes: 17 additions & 3 deletions cpp/llm_chat.cc
@@ -175,19 +175,25 @@ struct FunctionTable {
     }
   }
 
-  ObjectRef LoadParams(const std::string& model_path, Device device) {
+  ObjectRef LoadParams(const std::string& model_path, Device device, bool use_presharded_weights) {
     if (this->use_disco) {
       std::filesystem::path fs_model_path = model_path;
       std::string metadata_path = (fs_model_path / "ndarray-cache.json").string();
       std::string ndarray_cache_metadata = LoadBytesFromFile(metadata_path);
       PackedFunc loader_create = this->get_global_func("runtime.disco.ShardLoader");
-      PackedFunc loader_load_all = this->get_global_func("runtime.disco.ShardLoaderLoadAll");
+
+      auto load_all_func_name = use_presharded_weights
+                                    ? "runtime.disco.ShardLoaderLoadAllPresharded"
+                                    : "runtime.disco.ShardLoaderLoadAll";
+      PackedFunc loader_load_all = this->get_global_func(load_all_func_name);
       CHECK(loader_create != nullptr);
       CHECK(loader_load_all != nullptr);
       DRef loader = loader_create(metadata_path, ndarray_cache_metadata, "", this->disco_mod);
       DRef params = loader_load_all(loader);
       return params;
     } else {
+      CHECK(!use_presharded_weights) << "Use of pre-sharded weights requires more than one GPU";
+
       const PackedFunc* fload_cache = tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.load");
       ICHECK(fload_cache) << "TVM runtime cannot find vm.builtin.ndarray_cache.load";
       (*fload_cache)(model_path, static_cast<int32_t>(device.device_type), device.device_id);
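
The hunk above selects between two registered loader functions purely by name. A minimal Python sketch of the same selection against TVM's global-function registry (the two function names come from the diff; whether they are registered depends on how TVM was built):

import tvm

def pick_shard_loader(use_presharded_weights: bool):
    # Same name selection as LoadParams above: the Presharded variant loads
    # weight shards as-is instead of splitting full weights at load time.
    name = (
        "runtime.disco.ShardLoaderLoadAllPresharded"
        if use_presharded_weights
        else "runtime.disco.ShardLoaderLoadAll"
    )
    # allow_missing=True returns None rather than raising if unregistered.
    func = tvm.get_global_func(name, allow_missing=True)
    assert func is not None, f"runtime does not register {name}"
    return func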
@@ -387,6 +393,12 @@ class LLMChat {
     } else {
       this->num_shards_ = 1;
     }
+    if (config.count("use_presharded_weights")) {
+      CHECK(config["use_presharded_weights"].is<bool>());
+      this->use_presharded_weights_ = config["use_presharded_weights"].get<bool>();
+    } else {
+      this->use_presharded_weights_ = false;
+    }
     if (config.count("max_window_size")) {
       CHECK(config["max_window_size"].is<int64_t>());
       this->max_window_size_ =
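
The config lookup above is the usual optional-key-with-default pattern. Sketched in Python against a hypothetical chat-config JSON (the key name matches the diff; the file name is illustrative):

import json

# Hypothetical file name; the runtime reads the model's JSON chat config.
with open("mlc-chat-config.json", encoding="utf-8") as f:
    config = json.load(f)

# Mirrors the C++ above: the key is optional, must be a bool if present,
# and defaults to False so existing configs keep their behavior.
use_presharded_weights = config.get("use_presharded_weights", False)
assert isinstance(use_presharded_weights, bool)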
@@ -512,7 +524,7 @@ class LLMChat {
         << "Cannot find env function vm.builtin.sample_top_p_from_logits";
     fsample_topp_from_logits_ = *fsample_topp_from_logits_ptr;
     // Step 5. Load params in nd-array cache.
-    this->params_ = ft_.LoadParams(model_path, device_);
+    this->params_ = ft_.LoadParams(model_path, device_, use_presharded_weights_);
     // Step 6. KV cache creation.
     this->kv_cache_ = ft_.create_kv_cache_func_();
     // Step 7. Pre-allocate fixed size ndarray
@@ -1357,6 +1369,8 @@ class LLMChat {
   int64_t vocab_size_;
   // number of shards in distributed inference
   int64_t num_shards_;
+  // Load weights that were saved in sharded form
+  bool use_presharded_weights_;
   // shift window fill factor
   double shift_fill_factor_{0.3};
   // temperature
3 changes: 3 additions & 0 deletions python/mlc_chat/chat_module.py
@@ -151,6 +151,8 @@ class ChatConfig:  # pylint: disable=too-many-instance-attributes
         Name of the model (e.g. ``Llama-2-7b-chat-hf``).
     num_shards: Optional[int]
         Tensor parallel degree.
+    use_presharded_weights: Optional[bool]
+        If True, the weights were saved with sharding already applied.
     max_window_size: Optional[int]
         Maximum kv cache window size.
     """
@@ -169,6 +171,7 @@ class ChatConfig:  # pylint: disable=too-many-instance-attributes
     model_category: Optional[str] = None
     model_name: Optional[str] = None
     num_shards: Optional[int] = None
+    use_presharded_weights: Optional[bool] = None
     max_window_size: Optional[int] = None
 
     @classmethod
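A usage sketch for the new field, assuming a model compiled with num_shards=2 whose weights were exported in pre-sharded form; the model id and construction arguments are illustrative, not taken from this commit:

from mlc_chat import ChatModule
from mlc_chat.chat_module import ChatConfig

cm = ChatModule(
    model="Llama-2-7b-chat-hf-q4f16_1",  # hypothetical local model id
    chat_config=ChatConfig(num_shards=2, use_presharded_weights=True),
)
print(cm.generate("Hello!"))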
