@@ -4,6 +4,7 @@
 #include "llama.h"
 
 #include <ctime>
+#include <cstdio>
 #include <algorithm>
 
 #if defined(_MSC_VER)
@@ -70,6 +71,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 }
 
+// plain, pipe-friendly output: one embedding per line
+static void print_raw_embeddings(const float * emb,
+                                 int n_embd_count,
+                                 int n_embd,
+                                 const llama_model * model,
+                                 enum llama_pooling_type pooling_type,
+                                 int embd_normalize) {
+    const uint32_t n_cls_out = llama_model_n_cls_out(model);
+    const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
+    const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
+
+    for (int j = 0; j < n_embd_count; ++j) {
+        for (int i = 0; i < cols; ++i) {
+            if (embd_normalize == 0) {
+                printf("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
+            } else {
+                printf("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
+            }
+        }
+        printf("\n");
+    }
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
@@ -259,6 +283,10 @@ int main(int argc, char ** argv) {
     float * out = emb + e * n_embd;
     batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
 
+    if (params.embd_out == "raw") {
+        print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
+    }
+
     if (params.embd_out.empty()) {
         LOG("\n");
 
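For context: the new "raw" format writes one embedding per line as space-separated floats, so the output can be piped directly into other tools. Below is a minimal downstream-consumer sketch (not part of this change) that reads that format back from stdin; it assumes the output was produced via the example's existing --embd-output-format flag, e.g. `llama-embedding -m model.gguf -p "hello" --embd-output-format raw | ./consumer`, where model.gguf is a placeholder path and consumer is the hypothetical program below.

// consumer.cpp: sketch of parsing "raw" output (one embedding per
// line, space-separated floats) from stdin into memory.
#include <cstdio>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
    std::vector<std::vector<float>> embeddings;
    std::string line;
    while (std::getline(std::cin, line)) {
        std::istringstream iss(line);
        std::vector<float> row;
        float v;
        while (iss >> v) {
            row.push_back(v);
        }
        if (!row.empty()) {
            embeddings.push_back(std::move(row));
        }
    }
    std::fprintf(stderr, "read %zu embeddings of dim %zu\n",
                 embeddings.size(),
                 embeddings.empty() ? (size_t) 0 : embeddings[0].size());
    return 0;
}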