15 changes: 9 additions & 6 deletions common/common.cpp
@@ -1242,7 +1242,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
return res;
}

int err = llama_apply_adapter_cvec(
int err = llama_set_adapter_cvec(
lctx,
cvec.data.data(),
cvec.data.size(),
@@ -1344,12 +1344,15 @@ std::string get_model_endpoint() {
}

void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx);
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_set_adapter_lora(ctx, la.ptr, la.scale);
}
std::vector<llama_adapter_lora *> loras;
std::vector<float> scales;

for (auto & la: lora) {
loras.push_back(la.ptr);
scales.push_back(la.scale);
}

llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
}

struct llama_model_params common_model_params_to_llama(common_params & params) {
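With this change, `common_set_adapter_lora` hands the whole adapter set to the context in a single call instead of clearing and re-adding adapters one by one. A minimal usage sketch, assuming `ctx` is a valid `llama_context *` and `adapter_a`/`adapter_b` are previously loaded `llama_adapter_lora *` handles (both hypothetical):

```cpp
// Sketch only: activate exactly this set of adapters on the context.
std::vector<common_adapter_lora_info> lora;

common_adapter_lora_info la;
la.ptr   = adapter_a;  // hypothetical, previously loaded adapter handle
la.scale = 1.0f;
lora.push_back(la);

la.ptr   = adapter_b;  // second hypothetical adapter
la.scale = 0.5f;
lora.push_back(la);

common_set_adapter_lora(ctx, lora);   // replaces whatever set was active before

lora.clear();
common_set_adapter_lora(ctx, lora);   // an empty vector deactivates all adapters
```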
21 changes: 6 additions & 15 deletions include/llama.h
@@ -656,29 +656,20 @@ extern "C" {

// The following functions operate on a llama_context, hence the naming: llama_verb_...

// Add a loaded LoRA adapter to given context
// This will not modify model's weight
LLAMA_API int32_t llama_set_adapter_lora(
// Set the LoRA adapters on the context. Only modifies the context if the requested adapters/scales differ from those currently set.
LLAMA_API int32_t llama_set_adapters_lora(
struct llama_context * ctx,
struct llama_adapter_lora * adapter,
float scale);

// Remove a specific LoRA adapter from given context
// Return -1 if the adapter is not present in the context
LLAMA_API int32_t llama_rm_adapter_lora(
struct llama_context * ctx,
struct llama_adapter_lora * adapter);

// Remove all LoRA adapters from given context
LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
struct llama_adapter_lora ** adapters,
size_t n_adapters,
float * scales);

// Apply a loaded control vector to a llama_context, or if data is NULL, clear
// the currently loaded vector.
// n_embd should be the size of a single layer's control, and data should point
// to an n_embd x n_layers buffer starting from layer 1.
// il_start and il_end are the layer range the vector should apply to (both inclusive)
// See llama_control_vector_load in common to load a control vector.
LLAMA_API int32_t llama_apply_adapter_cvec(
LLAMA_API int32_t llama_set_adapter_cvec(
struct llama_context * ctx,
const float * data,
size_t len,
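At the C API level, the per-adapter add/remove/clear calls are folded into a single setter that takes parallel arrays of adapters and scales. A hedged sketch of how a caller might use it; `ctx`, `adapter_a`, and `adapter_b` are assumed valid handles, and the scales are illustrative:

```cpp
// Sketch only: set two adapters with their scales in one call.
llama_adapter_lora * adapters[] = { adapter_a, adapter_b };  // hypothetical handles
float                scales[]   = { 1.0f, 0.5f };

int32_t ret = llama_set_adapters_lora(ctx, adapters, 2, scales);
// ret is 0 on success; the call is a no-op when the same adapters/scales
// are already set on the context.

// Passing an empty set removes all adapters (this replaces llama_clear_adapter_lora).
llama_set_adapters_lora(ctx, nullptr, 0, nullptr);
```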
91 changes: 38 additions & 53 deletions src/llama-context.cpp
@@ -1061,51 +1061,43 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
return true;
}

void llama_context::set_adapter_lora(
llama_adapter_lora * adapter,
float scale) {
LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);

if (auto it = loras.find(adapter); it != loras.end()) {
if (it->second == scale) {
return;
}
if (adapters_lora_are_same(adapters, n_adapters, scales)) {
return;
}

loras[adapter] = scale;
loras.clear();

for (size_t i = 0; i < n_adapters; i ++) {
if (scales[i] != 0.0f) {
loras[adapters[i]] = scales[i];
}
}

sched_need_reserve = true;
}

bool llama_context::rm_adapter_lora(
llama_adapter_lora * adapter) {
LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);

auto it = loras.find(adapter);
if (it != loras.end()) {
loras.erase(it);
bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);

sched_need_reserve = true;

return true;
if (n_adapters != loras.size()) {
return false;
}

return false;
}

void llama_context::clear_adapter_lora() {
LLAMA_LOG_DEBUG("%s: call\n", __func__);
for (size_t i = 0; i < n_adapters; i ++) {
auto it = loras.find(adapters[i]);

if (loras.empty()) {
return;
if (it == loras.end() || it->second != scales[i]) {
return false;
}
}

loras.clear();

sched_need_reserve = true;
return true;
}

bool llama_context::apply_adapter_cvec(
bool llama_context::set_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
@@ -3222,35 +3214,28 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {

// llama adapter API

int32_t llama_set_adapter_lora(
int32_t llama_set_adapters_lora(
llama_context * ctx,
llama_adapter_lora * adapter,
float scale) {
ctx->set_adapter_lora(adapter, scale);

return 0;
}
llama_adapter_lora ** adapters,
size_t n_adapters,
float * scales) {
if (adapters == nullptr || scales == nullptr) {
GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call");
}

int32_t llama_rm_adapter_lora(
llama_context * ctx,
llama_adapter_lora * adapter) {
bool res = ctx->rm_adapter_lora(adapter);
ctx->set_adapters_lora(adapters, n_adapters, scales);

return res ? 0 : -1;
}

void llama_clear_adapter_lora(llama_context * ctx) {
ctx->clear_adapter_lora();
return 0;
}

int32_t llama_apply_adapter_cvec(
int32_t llama_set_adapter_cvec(
llama_context * ctx,
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end) {
bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end) {
bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end);

return res ? 0 : -1;
}
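The control-vector entry point is renamed from `llama_apply_adapter_cvec` to `llama_set_adapter_cvec` but keeps its signature. A hedged sketch, where `cvec_data`, `n_embd`, `n_layers`, `il_start`, and `il_end` are illustrative placeholders:

```cpp
// Sketch only: cvec_data points to an n_embd x n_layers buffer (starting from layer 1),
// and il_start/il_end give the inclusive layer range to apply it to.
int32_t err = llama_set_adapter_cvec(ctx, cvec_data, (size_t) n_embd * n_layers,
                                     n_embd, il_start, il_end);

// Per the header comment, passing NULL data clears the currently loaded vector.
llama_set_adapter_cvec(ctx, nullptr, 0, n_embd, il_start, il_end);
```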
11 changes: 3 additions & 8 deletions src/llama-context.h
@@ -104,16 +104,11 @@ struct llama_context {
void set_causal_attn(bool value);
void set_warmup(bool value);

void set_adapter_lora(
llama_adapter_lora * adapter,
float scale);
void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);

bool rm_adapter_lora(
llama_adapter_lora * adapter);
bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);

void clear_adapter_lora();

bool apply_adapter_cvec(
bool set_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,