Commit 0d8999a
Add compatibility with #801
slaren committed Apr 10, 2023
1 parent 4ae3167
Showing 2 changed files with 10 additions and 3 deletions.
examples/common.cpp (2 additions, 1 deletion)
@@ -151,6 +151,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.lora_adapter = argv[i];
+            params.use_mmap = false;
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
@@ -254,7 +255,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     }
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
-    fprintf(stderr, "  --lora FNAME          apply LoRA adapter\n");
+    fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
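The `use_mmap = false` line is the heart of the examples/common.cpp change: applying a LoRA adapter patches model weights in memory, which the read-only memory-mapped model from #801 cannot accommodate, so requesting `--lora` now forces a regular load. Below is a minimal standalone sketch of this option interaction; the `opts` struct is a hypothetical stand-in for the real `gpt_params`, not the actual llama.cpp code.

#include <cstdio>
#include <string>

struct opts {
    std::string lora_adapter;   // empty = no adapter requested
    bool use_mmap = true;       // default: memory-map the model file
};

int main(int argc, char ** argv) {
    opts o;
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "--lora" && i + 1 < argc) {
            o.lora_adapter = argv[++i];
            o.use_mmap = false;  // adapters patch weights in place, so no mmap
        } else if (arg == "--no-mmap") {
            o.use_mmap = false;
        }
    }
    printf("lora='%s' use_mmap=%d\n", o.lora_adapter.c_str(), o.use_mmap);
    return 0;
}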
llama.cpp (8 additions, 2 deletions)
@@ -1800,6 +1800,12 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
     ggml_context* lora_ctx = ggml_init(params);
     std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
 
+    // create a name -> tensor map of the model to accelerate lookups
+    std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
+    for (auto & kv: model.tensors_by_name) {
+        model_tensors.insert(kv);
+    }
+
     fprintf(stderr, "%s: ", __func__);
 
     // read tensors and apply
@@ -1839,7 +1845,7 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
         base_name.erase(pos);
         // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str());
 
-        if (model.tensors.find(base_name.data()) == model.tensors.end()) {
+        if (model_tensors.find(base_name.data()) == model_tensors.end()) {
            fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
            return 1;
        }
@@ -1878,7 +1884,7 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
         if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
             lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
 
-            ggml_tensor * tensor = model.tensors[base_name];
+            ggml_tensor * tensor = model_tensors[base_name];
             ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
             ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
 
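The llama.cpp side adapts the LoRA loader to the post-#801 model layout, where tensors are exposed as a `tensors_by_name` list of (name, tensor) pairs rather than the old `model.tensors` map: the loader builds its own hash map once up front, so every subsequent adapter lookup is O(1) instead of a scan over the list. A self-contained sketch of that pattern follows; `tensor_t` and the sample tensor names are illustrative stand-ins for `ggml_tensor` and real model weights.

#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct tensor_t { int id; };  // illustrative stand-in for ggml_tensor

int main() {
    tensor_t wq{0}, wk{1};

    // post-#801 layout: an ordered list of (name, tensor) pairs, not a map
    std::vector<std::pair<std::string, tensor_t *>> tensors_by_name = {
        {"layers.0.attention.wq.weight", &wq},
        {"layers.0.attention.wk.weight", &wk},
    };

    // build the hash map once; each adapter tensor is then resolved in O(1)
    // instead of scanning the whole list per lookup
    std::unordered_map<std::string, tensor_t *> model_tensors;
    for (auto & kv : tensors_by_name) {
        model_tensors.insert(kv);
    }

    const std::string base_name = "layers.0.attention.wq.weight";
    auto it = model_tensors.find(base_name);
    if (it == model_tensors.end()) {
        fprintf(stderr, "unknown tensor '%s'\n", base_name.c_str());
        return 1;
    }
    printf("found '%s' (id %d)\n", base_name.c_str(), it->second->id);
    return 0;
}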
