Skip to content

Commit 7f59af5

Browse files
Azeirah authored and SlyEcho committed
Steer with inpSA instead of with inpL
Signed-off-by: Henri Vasserman <henv@hot.ee>
1 parent 1b0ff2c commit 7f59af5

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

examples/main/main.cpp

+12-13
Original file line numberDiff line numberDiff line change
@@ -176,28 +176,27 @@ int main(int argc, char ** argv) {
176176

177177
if (!params.steering_add.empty() || !params.steering_sub.empty())
178178
{
179-
params.steering_add.insert(0, 1, ' ');
180-
params.steering_sub.insert(0, 1, ' ');
181-
182179
auto add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
183180
auto sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);
184181

185-
//if (add_tokens.size() != sub_tokens.size()) {
186-
// while (add_tokens.size() < sub_tokens.size()) {
187-
// add_tokens.push_back(llama_token_nl());
188-
// }
189-
// while (sub_tokens.size() < add_tokens.size()) {
190-
// sub_tokens.push_back(llama_token_nl());
191-
// }
192-
//}
193-
//const int N = embd_inp.size();
182+
183+
if (add_tokens.size() != sub_tokens.size()) {
184+
while (add_tokens.size() < sub_tokens.size()) {
185+
add_tokens.push_back(llama_token_nl());
186+
}
187+
while (sub_tokens.size() < add_tokens.size()) {
188+
sub_tokens.push_back(llama_token_nl());
189+
}
190+
}
191+
194192
llama_set_steering_write(ctx, params.steering_source, +1.0f);
195193
llama_eval(ctx, add_tokens.data(), std::min((int)add_tokens.size(), n_ctx), 0, params.n_threads);
196194

197-
llama_set_steering_write(ctx, params.steering_layer, -1.0f);
195+
llama_set_steering_write(ctx, params.steering_source, -1.0f);
198196
llama_eval(ctx, sub_tokens.data(), std::min((int)sub_tokens.size(), n_ctx), 0, params.n_threads);
199197

200198
llama_set_steering_read(ctx, params.steering_layer, params.steering_mul);
199+
std::cout << "Steering: `" << params.steering_add << "` - `" << params.steering_sub << "` * " << params.steering_mul << "\n";
201200
}
202201

203202
// debug message about similarity of saved session, if applicable

llama.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <mutex>
3333
#include <sstream>
3434
#include <numeric>
35+
#include <iostream>
3536

3637
#define LLAMA_USE_SCRATCH
3738
#define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -1187,8 +1188,8 @@ static bool llama_eval_internal(
11871188
ggml_add(ctx0, ggml_scale(ctx0, inpL, scal), steer), steer));
11881189
break;
11891190
}
1190-
1191-
inpL = ggml_add(ctx0, ggml_scale(ctx0, steer, scal), inpL);
1191+
// std::cout << "\nAdding steering vector to inpL " << il << "\n";
1192+
inpSA = ggml_add(ctx0, ggml_scale(ctx0, steer, scal), inpSA);
11921193
}
11931194

11941195
// norm

0 commit comments

Comments (0)