diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 01ec22aa3cc28..093e769e338f3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2494,7 +2494,7 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vocab_size (4096) # TODO: Fix + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"]) self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", self.hparams["codebook_dim"]) self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) # 1024 diff --git a/examples/tts/orpheus-tts.cpp b/examples/tts/orpheus-tts.cpp index 45595e9552fc0..a7f0e16dfa296 100644 --- a/examples/tts/orpheus-tts.cpp +++ b/examples/tts/orpheus-tts.cpp @@ -1,6 +1,5 @@ #include "common.h" #include "llama.h" -#include "llama-impl.h" #include "log.h" #include "arg.h" #include "sampling.h" @@ -19,148 +18,30 @@ #include #include -std::vector redistribute_codes(const std::vector& raw_codes) { - std::vector snac_codes; - for (size_t i = 0; i < raw_codes.size(); i += 7) { - // Ensure we have a full frame (7 codes) - if (i + 6 >= raw_codes.size()) break; - - // Frame offsets (per notebook) - snac_codes.push_back(raw_codes[i]); // Codebook 0 (no offset) - snac_codes.push_back(raw_codes[i+1] - 4096); // Codebook 1 - snac_codes.push_back(raw_codes[i+2] - 8192); // Codebook 2 - snac_codes.push_back(raw_codes[i+3] - 12288); // Codebook 2 - snac_codes.push_back(raw_codes[i+4] - 16384); // Codebook 1 - snac_codes.push_back(raw_codes[i+5] - 20480); // Codebook 2 - snac_codes.push_back(raw_codes[i+6] - 24576); // Codebook 2 - } - return snac_codes; -} - -static std::vector embd_to_audio( - const float * embd, - const int n_codes, - const int n_embd, - const int n_thread); -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate); -static void fill_hann_window(int length, bool periodic, float * output); -static void irfft(int n, const float * inp_cplx, float * out_real); -static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output); - -static void print_usage(int /*argc*/, char **argv) { - LOG("\nexample usage:\n"); - LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]); - LOG("\n"); -} - -static void prompt_add(std::vector &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) { - auto tmp = common_tokenize(vocab, txt, add_special, parse_special); - prompt.insert(prompt.end(), tmp.begin(), tmp.end()); -} - - -// // Include embd_to_audio and save_wav16 from tts.cpp (for now) -static std::vector embd_to_audio( - const float * embd, - const int n_codes, - const int n_embd, - const int n_thread) { - const int n_fft = 1280; - const int n_hop = 320; - const int n_win = 1280; - const int n_pad = (n_win - n_hop)/2; - const int n_out = (n_codes - 1)*n_hop + n_win; - - std::vector hann(n_fft); - fill_hann_window(hann.size(), true, hann.data()); - - int n_spec = n_embd*n_codes; - - std::vector E (n_spec); - std::vector S (n_spec); - std::vector ST(n_spec); - - for (int l = 0; l < n_codes; ++l) { - for (int k = 0; k < n_embd; ++k) { - E[k*n_codes + l] = embd[l*n_embd + k]; - } - } - - for (int k = 0; k < n_embd/2; ++k) { - for (int l = 0; l < n_codes; ++l) { - float mag = E[(k )*n_codes + l]; - float phi = E[(k + n_embd/2)*n_codes + l]; - mag = exp(mag); - if (mag > 1e2) { - mag = 1e2; - } - S[2*(k*n_codes + 
l) + 0] = mag*cosf(phi);
-            S[2*(k*n_codes + l) + 1] = mag*sinf(phi);
-        }
-    }
-
-    for (int l = 0; l < n_codes; ++l) {
-        for (int k = 0; k < n_embd/2; ++k) {
-            ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
-            ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
-        }
-    }
-
-    std::vector<float> res (n_codes*n_fft);
-    std::vector<float> hann2(n_codes*n_fft);
-
-    std::vector<std::thread> workers(n_thread);
-    for (int i = 0; i < n_thread; ++i) {
-        workers[i] = std::thread([&, i]() {
-            for (int l = i; l < n_codes; l += n_thread) {
-                irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
-                for (int j = 0; j < n_fft; ++j) {
-                    res [l*n_fft + j] *= hann[j];
-                    hann2[l*n_fft + j] = hann[j] * hann[j];
-                }
-            }
-        });
-    }
-    for (int i = 0; i < n_thread; ++i) {
-        workers[i].join();
-    }
-
-    std::vector<float> audio;
-    std::vector<float> env;
-
-    fold(res, n_out, n_win, n_hop, n_pad, audio);
-    fold(hann2, n_out, n_win, n_hop, n_pad, env);
-
-    for (size_t i = 0; i < audio.size(); ++i) {
-        audio[i] /= env[i];
-    }
-
-    return audio;
-}
-
-static bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
+struct wav_header {
+    char riff[4] = {'R', 'I', 'F', 'F'};
+    uint32_t chunk_size;
+    char wave[4] = {'W', 'A', 'V', 'E'};
+    char fmt[4] = {'f', 'm', 't', ' '};
+    uint32_t fmt_chunk_size = 16;
+    uint16_t audio_format = 1; // PCM
+    uint16_t num_channels = 1; // Mono
+    uint32_t sample_rate;
+    uint32_t byte_rate;
+    uint16_t block_align;
+    uint16_t bits_per_sample = 16;
+    char data[4] = {'d', 'a', 't', 'a'};
+    uint32_t data_size;
+};
+
+static bool save_wav16(const std::string &fname, const std::vector<float> &data, int sample_rate) {
     std::ofstream file(fname, std::ios::binary);
     if (!file) {
         LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str());
         return false;
     }
 
-    struct wav_header {
-        char riff[4] = {'R', 'I', 'F', 'F'};
-        uint32_t chunk_size;
-        char wave[4] = {'W', 'A', 'V', 'E'};
-        char fmt[4] = {'f', 'm', 't', ' '};
-        uint32_t fmt_chunk_size = 16;
-        uint16_t audio_format = 1; // PCM
-        uint16_t num_channels = 1; // Mono
-        uint32_t sample_rate;
-        uint32_t byte_rate;
-        uint16_t block_align;
-        uint16_t bits_per_sample = 16;
-        char data[4] = {'d', 'a', 't', 'a'};
-        uint32_t data_size;
-    } header;
-
+    wav_header header;
     header.sample_rate = sample_rate;
     header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
     header.block_align = header.num_channels * (header.bits_per_sample / 8);
@@ -169,95 +50,49 @@ static bool save_wav16(const std::string & fname, const std::vector<float> & dat
     file.write(reinterpret_cast<const char *>(&header), sizeof(header));
 
-    for (const auto & sample : data) {
-        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
+    for (const auto &sample : data) {
+        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0f, -32768.0f, 32767.0f));
         file.write(reinterpret_cast<const char *>(&pcm_sample), sizeof(pcm_sample));
     }
 
     return file.good();
 }
 
-// Supporting functions from tts.cpp (for embd_to_audio)
-static void fill_hann_window(int length, bool periodic, float * output) {
-    int offset = -1;
-    if (periodic) {
-        offset = 0;
-    }
-    for (int i = 0; i < length; i++) {
-        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
-    }
+std::vector<llama_token> redistribute_codes(const std::vector<llama_token>& raw_codes) {
+    std::vector<llama_token> snac_codes;
+    for (size_t i = 0; i < raw_codes.size(); i += 7) {
+        if (i + 6 >= raw_codes.size()) break;
+
+        // Subtract 128266 base and layer-specific offsets
+        snac_codes.push_back(raw_codes[i] - 128266);             // Layer 1: offset 0
+        snac_codes.push_back(raw_codes[i + 1] - 128266 - 4096);  // Layer 2: offset 4096
+        snac_codes.push_back(raw_codes[i + 2] - 128266 - 8192);  // Layer 3: offset 8192
+        snac_codes.push_back(raw_codes[i + 3] - 128266 - 12288); // Layer 3: offset 12288
+        snac_codes.push_back(raw_codes[i + 4] - 128266 - 16384); // Layer 2: offset 16384
+        snac_codes.push_back(raw_codes[i + 5] - 128266 - 20480); // Layer 3: offset 20480
+        snac_codes.push_back(raw_codes[i + 6] - 128266 - 24576); // Layer 3: offset 24576
+    }
+    return snac_codes;
 }
 
-static void twiddle(float * real, float * imag, int k, int N) {
-    float angle = 2 * M_PI * k / N;
-    *real = cos(angle);
-    *imag = sin(angle);
-}
-
-static void irfft(int n, const float * inp_cplx, float * out_real) {
-    int N = n / 2 + 1;
-
-    std::vector<float> real_input(N);
-    std::vector<float> imag_input(N);
-    for (int i = 0; i < N; ++i) {
-        real_input[i] = inp_cplx[2 * i];
-        imag_input[i] = inp_cplx[2 * i + 1];
-    }
-
-    std::vector<float> real_output(n);
-    std::vector<float> imag_output(n);
-
-    for (int k = 0; k < n; ++k) {
-        real_output[k] = 0.0f;
-        imag_output[k] = 0.0f;
-        for (int m = 0; m < N; ++m) {
-            float twiddle_real;
-            float twiddle_imag;
-
-            twiddle(&twiddle_real, &twiddle_imag, k * m, n);
-
-            real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
-            imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
-        }
-    }
-
-    for (int i = 0; i < n; ++i) {
-        out_real[i] = real_output[i] / N;
-    }
+static void print_usage(int /*argc*/, char **argv) {
+    LOG("\nexample usage:\n");
+    LOG("\n    %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]);
+    LOG("\n");
 }
 
-static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
-    int64_t output_height = n_out;
-    int64_t kernel_w = n_win;
-    int64_t stride_w = n_hop;
-    int64_t width = n_out;
-
-    output.resize(width, 0.0f);
-
-    int64_t col_idx = 0;
-    for (int64_t w_col = 0; w_col < width; ++w_col) {
-        int64_t start = w_col * stride_w - n_pad;
-        int64_t end = start + kernel_w;
-
-        for (int64_t w_im = start; w_im < end; ++w_im) {
-            if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) {
-                output[w_im] += data[col_idx];
-            }
-            col_idx++;
-        }
-    }
-
-    output.resize(n_out - 2 * n_pad);
+static void prompt_add(std::vector<llama_token> &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) {
+    auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
    prompt.insert(prompt.end(), tmp.begin(), tmp.end());
 }
 
 int main(int argc, char **argv) {
     common_params params;
-
+
     params.model = "models/orpheus-3b-0.1-ft-q4_k_m.gguf";
-    params.vocoder.model = "models/snac-vocab.gguf";
+    params.vocoder.model = "models/snac-fwd-pass-devel.gguf";
     params.out_file = "output.wav";
-    params.n_predict = 1200;
     params.sampling.top_k = 4;
     params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
     params.n_batch = 4096;
@@ -265,7 +100,8 @@ int main(int argc, char **argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
+
+
     common_init_result orpheus_init_ttc = common_init_from_params(params);
 
     llama_model * model_ttc = NULL;
@@ -290,17 +126,15 @@ int main(int argc, char **argv) {
     prompt_add(tokens, vocab, "", false, true); // Emotion tag
     tokens.push_back(128009); // <|eot_id|>
     tokens.push_back(128260); // <|endofhuman|>
-
+
     llama_model * model_cts = NULL;
     llama_context * ctx_cts = NULL;
 
     params.model = params.vocoder.model;
-    params.n_batch = 2;
     params.embedding = true;
-    // disable warmup, SNAC doesn't care about BOS or EOS tokens;
-    params.warmup = false;
+    params.warmup = false; // SNAC doesn't care about BOS or EOS tokens
 
     common_init_result snac_init_cts = common_init_from_params(params);
     LOG_INF("SNAC model loaded: %s\n", params.model.c_str());
@@ -308,35 +142,80 @@ int main(int argc, char **argv) {
     model_cts = snac_init_cts.model.get();
     ctx_cts = snac_init_cts.context.get();
 
-    std::vector<llama_token> speech_codes = {100, 4200, 8500, 12500, 16500, 21000, 25000,
-                                             200, 4300, 8600, 12600, 16600, 21111, 25100};
-
-    std::vector<llama_token> snac_codes = redistribute_codes(speech_codes);
-
-    const int n_codes = speech_codes.size();
-    const int batch_size = n_codes;
-
-    llama_batch batch = llama_batch_init(batch_size, 0, 1);
-
-    for (size_t i = 0; i < n_codes; ++i) {
+    // TODO: Use real orpheus codes
+    // Just some random numbers for testing
+    std::vector<llama_token> orpheus_codes = {
+        // Frame 1, 7 codes per frame
+        128266 + 100,        // L1: 100
+        128266 + 4096 + 200, // L2: 200
+        128266 + 8192 + 300, // L3: 300
+        128266 + 12288 + 400,// L3: 400
+        128266 + 16384 + 500,// L2: 500
+        128266 + 20480 + 600,// L3: 600
+        128266 + 24576 + 700,// L3: 700
+        // Frame 2
+        128266 + 150, 128266 + 4096 + 250, 128266 + 8192 + 350, 128266 + 12288 + 450,
+        128266 + 16384 + 550, 128266 + 20480 + 650, 128266 + 24576 + 750,
+        // Frame 3
+        128266 + 110, 128266 + 4096 + 210, 128266 + 8192 + 310, 128266 + 12288 + 410,
+        128266 + 16384 + 510, 128266 + 20480 + 610, 128266 + 24576 + 710,
+        // Frame 4
+        128266 + 120, 128266 + 4096 + 220, 128266 + 8192 + 320, 128266 + 12288 + 420,
+        128266 + 16384 + 520, 128266 + 20480 + 620, 128266 + 24576 + 720,
+        // Frame 5
+        128266 + 130, 128266 + 4096 + 230, 128266 + 8192 + 330, 128266 + 12288 + 430,
+        128266 + 16384 + 530, 128266 + 20480 + 630, 128266 + 24576 + 730,
+        // Frame 6
+        128266 + 140, 128266 + 4096 + 240, 128266 + 8192 + 340, 128266 + 12288 + 440,
+        128266 + 16384 + 540, 128266 + 20480 + 640, 128266 + 24576 + 740,
+        // Frame 7
+        128266 + 160, 128266 + 4096 + 260, 128266 + 8192 + 360, 128266 + 12288 + 460,
+        128266 + 16384 + 560, 128266 + 20480 + 660, 128266 + 24576 + 760,
+        // Frame 8
+        128266 + 170, 128266 + 4096 + 270, 128266 + 8192 + 370, 128266 + 12288 + 470,
+        128266 + 16384 + 570, 128266 + 20480 + 670, 128266 + 24576 + 770,
+        // Frame 9
+        128266 + 180, 128266 + 4096 + 280, 128266 + 8192 + 380, 128266 + 12288 + 480,
+        128266 + 16384 + 580, 128266 + 20480 + 680, 128266 + 24576 + 780,
+        // Frame 10
+        128266 + 190, 128266 + 4096 + 290, 128266 + 8192 + 390, 128266 + 12288 + 490,
+        128266 + 16384 + 590, 128266 + 20480 + 690, 128266 + 24576 + 790
+    };
+
+    std::vector<llama_token> snac_codes = redistribute_codes(orpheus_codes);
+
+    const int batch_size = snac_codes.size();
+
+    llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+    for (size_t i = 0; i < batch_size; ++i) {
         common_batch_add(batch, snac_codes[i], i, {0}, true);
     }
 
     LOG_INF("Batch before decode: n_tokens = %d\n", batch.n_tokens);
 
-    if (llama_decode(ctx_cts, batch) != 0) { /* error */ }
-
-    if (llama_decode(ctx_cts, batch) != 0) { /* error */ }
-    GGML_ASSERT(batch.n_tokens == n_codes);
+    GGML_ASSERT(batch.n_tokens == batch_size);
     batch.logits[batch.n_tokens - 1] = true;
-
+
     if (llama_decode(ctx_cts, batch) != 0) {
         LOG_ERR("Failed to decode SNAC batch\n");
         return 1;
     }
 
-    llama_synchronize(ctx_cts);
-    LOG_INF("SNAC decode completed\n");
+    llama_synchronize(ctx_cts);
+
+    float* embd = llama_get_embeddings(ctx_cts);
+    if (!embd) {
+        LOG_ERR("No embeddings available\n");
+        return 1;
+    }
+
+    int n_samples = llama_get_n_outputs(ctx_cts);
+    std::vector<float>
audio(n_samples); + LOG_INF("n_samples: %i\n", n_samples); + memcpy(audio.data(), embd, n_samples * sizeof(float)); + + save_wav16(params.out_file, audio, 24000); llama_batch_free(batch); llama_backend_free(); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index def6eb3423c61..7bded06f88a94 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -14894,6 +14894,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: + case GGML_OP_SNAKE: { n_tasks = 1; } break; diff --git a/include/llama.h b/include/llama.h index 6a44be404d914..f98f1910bcf1c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -629,6 +629,8 @@ extern "C" { llama_seq_id * cells_sequences; }; + LLAMA_API int32_t llama_get_n_outputs(struct llama_context * ctx); + // Create an empty KV cache view. (use only for debugging purposes) LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5bec63e2e79ff..d15061655da39 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -851,6 +851,10 @@ float * llama_context::get_logits_ith(int32_t i) { } } +int32_t llama_context::get_n_outputs() { + return n_outputs; +} + float * llama_context::get_embeddings() { // reorder embeddings for backward compatibility output_reorder(); @@ -1403,10 +1407,21 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT(embd != nullptr); float * embd_out = embd + n_outputs_prev*n_embd; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + if (model.arch == LLM_ARCH_SNAC_DEC) { + // TODO: hack, SNAC outputs audio samples, not embeddings + // Rely on n_outputs for now, but perhaps add an `n_samples_snac` to + // llama_context to avoid doing these checks + int64_t n_samples = t_embd->ne[0]; + if (n_samples > 0) { + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_samples * sizeof(float)); + n_outputs = n_samples; // Update for downstream + } + } else { + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs * n_embd * sizeof(float)); + } } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -1471,8 +1486,11 @@ int llama_context::decode(llama_batch & inp_batch) { } } - // set to total number of outputs in the batch, for use in llama_get_logits_ith - n_outputs = n_outputs_all; + // TODO: Hack for now to avoid overwriting n_outputs in previous step + if (model.arch != LLM_ARCH_SNAC_DEC) { + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + } // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); @@ -2417,6 +2435,12 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } +int32_t llama_get_n_outputs(struct llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_n_outputs(); +} + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); diff --git a/src/llama-context.h b/src/llama-context.h index 04facb544cb1a..ff9ad663d1fe5 
100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -48,6 +48,8 @@ struct llama_context { float * get_logits(); float * get_logits_ith(int32_t i); + int32_t get_n_outputs(); + float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bee6e6bd359b4..4051c42852039 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1319,13 +1319,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_SNAC_DEC: { - hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; // From decoder_channel_dims + // TODO: Read from GGUF + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; hparams.upsample_rates = {8, 8, 4, 2}; hparams.n_embd = 768; hparams.n_layer = 8; - // Dummy KV cache params to satisfy llama.cpp - for (uint32_t i = 0; i < 7; ++i) { // n_total_layers = 8 + // Dummy KV cache params to satisfy init error + for (uint32_t i = 0; i < hparams.n_layer; ++i) { hparams.n_head_arr[i] = 1; hparams.n_head_kv_arr[i] = 1; } @@ -3716,8 +3717,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {8, 4096, 1}, 0); - hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; - // Quantizer projection tensors (0, 1, 2) for (int qid = 0; qid < 3; ++qid) { fprintf(stderr, "%s: Loading quantizer %d tensors\n", __func__, qid); @@ -3782,49 +3781,49 @@ bool llama_model::load_tensors(llama_model_loader & ml) { break; case 3: // Block 3: Residual Unit 1 { - int res_unit_idx = 0; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); } break; case 4: // Block 4: Residual Unit 2 { - int res_unit_idx = 1; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = 
create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); } break; case 5: // Block 5: Residual Unit 3 { - int res_unit_idx = 2; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); } break; default: fprintf(stderr, "%s: ERROR: Unexpected block id %d in layer %d\n", __func__, bid, i); - return false; // Or handle error appropriately + return false; } fprintf(stderr, "%s: Layer %d, Block %d: Finished\n", __func__, i, bid); - } // End block loop + } } else if (i == 6) { // --- Layer 6: Alpha --- layer.alpha = create_tensor(tn(LLM_TENSOR_ALPHA, i, -1), {1, n_in, 1}, 0); @@ -3834,9 +3833,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S7, i, -1), {1, 1, n_out}, 0); layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B7, i, -1), {n_out}, 0); } - else { // Should not happen + else { fprintf(stderr, "%s: ERROR: Unexpected layer index %d\n", __func__, i); - return false; // Or handle 
error appropriately + return false; } fprintf(stderr, "%s: Layer %d: Finished\n", __func__, i); } @@ -11744,286 +11743,230 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; -// struct llm_build_snac_dec : public llm_graph_context { - -// llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { -// LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); -// for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { -// LLAMA_LOG_INFO("%d ", ubatch.token[i]); -// } -// LLAMA_LOG("\n"); -// LLAMA_LOG_DEBUG("%s: Entering constructor, model.layers.size() = %zu\n", __func__, model.layers.size()); -// ggml_tensor * cur; -// ggml_tensor * inpL; - -// // TODO: probalby just get raw codes -// //cur = build_inp_embd(model.tok_embd); -// //LLAMA_LOG_INFO("After build_inp_embd: shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // hack, hardcode expected SNAC input at first conv layer -// cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); // [channels, seq_len, 1, 1] -// ggml_set_input(cur); -// LLAMA_LOG_INFO("hardcoded shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // end hack - -// // Log input tokens before processing -// LLAMA_LOG_INFO("%s: ubatch.n_tokens = %u\n", __func__, ubatch.n_tokens); -// LLAMA_LOG_WARN("%s: Input tokens from ubatch = ", __func__); -// for (uint32_t i = 0; i < ubatch.n_tokens && i < 20; ++i) { -// LLAMA_LOG_INFO("%d ", ubatch.token[i]); -// } -// if (ubatch.n_tokens > 20) LLAMA_LOG_INFO("..."); -// LLAMA_LOG("\n"); - -// // ggml_tensor * layer_1; -// // ggml_tensor * layer_2; -// // ggml_tensor * layer_3; -// //redistribute_codes(cur, &layer_1, &layer_2, &layer_3); - -// // Log the redistributed layers -// //log_tensor("Layer 1", layer_1); -// //log_tensor("Layer 2", layer_2); -// //log_tensor("Layer 3", layer_3); - -// for (uint32_t il = 1; il < model.layers.size(); ++il) { -// const auto & layer = model.layers[il]; - -// LLAMA_LOG_DEBUG("%s: Layer %u: Starting, cur = %p\n", __func__, il, cur); - -// if (il == 1) { // pointwise -// LLAMA_LOG_INFO("%s: Layer %u: Pointwise conv, conv_w = %p, conv_s = %p, conv_b = %p\n", -// __func__, il, layer.conv_w, layer.conv_s, layer.conv_b); -// LLAMA_LOG_INFO("Before transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// cur = ggml_transpose(ctx0, cur); // [768, 512] -> [512, 768] -// LLAMA_LOG_INFO("After transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 0); -// LLAMA_LOG_INFO("%s: Layer %u: After pointwise conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } else if (il == model.layers.size() - 1) { -// LLAMA_LOG_INFO("%s: Layer %u: Final layer, alpha = %p, conv_w = %p, conv_s = %p, conv_b = %p\n", -// __func__, il, layer.alpha, layer.conv_w, layer.conv_s, layer.conv_b); -// cur = ggml_snake(ctx0, cur, layer.alpha); -// LLAMA_LOG_INFO("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); -// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 3); -// LLAMA_LOG_INFO("%s: Layer %u: After final conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = ggml_tanh(ctx0, cur); -// LLAMA_LOG_INFO("%s: Layer %u: After ggml_tanh, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } else { -// // Layers 2-5: Decoder Blocks (1024 -> 512 -> 256 -> 128 -> 64) -// const int stride = hparams.upsample_rates[il - 2]; // 8 for il = 2 -// const int padding = stride; - -// // Block 0: Snake activation -// const auto & block0 = layer.decoder_blocks[0]; -// LLAMA_LOG_DEBUG("%s: Layer %u: Block 0, alpha = %p\n", __func__, il, block0.alpha); -// cur = ggml_snake(ctx0, cur, block0.alpha); -// LLAMA_LOG_DEBUG("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // Block 1: Transposed convolution -// const auto & block1 = layer.decoder_blocks[1]; -// LLAMA_LOG_DEBUG("%s: Layer %u: Block 1, stride = %d, up_weight = %p, up_scale = %p, up_bias = %p\n", -// __func__, il, stride, block1.up_weight, block1.up_scale, block1.up_bias); - -// cur = apply_conv1d_transpose(cur, block1.up_weight, block1.up_scale, block1.up_bias, stride, padding); -// LLAMA_LOG_DEBUG("%s: Layer %u: After conv1d_transpose, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // Residual Units (3 per block) -// for (int j = 0; j < 3; ++j) { -// const auto & ru = block1.res_units[j]; -// ggml_tensor * inpL = cur; -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Starting, inpL = %p, alpha1 = %p, conv1_w = %p, conv1_s = %p, conv1_b = %p\n", -// __func__, il, j, inpL, ru.alpha1, ru.conv1_w, ru.conv1_s, ru.conv1_b); - -// cur = ggml_snake(ctx0, cur, ru.alpha1); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// int dilation = (j == 0) ? 1 : (j == 1) ? 3 : 9; -// int padding = 3 * dilation; // Kernel 7, dilated padding = (7-1)/2 * dilation -// cur = apply_conv1d(cur, ru.conv1_w, ru.conv1_s, ru.conv1_b, 1, padding); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); - -// // pw -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Pointwise, alpha2 = %p, conv2_w = %p, conv2_s = %p, conv2_b = %p\n", -// __func__, il, j, ru.alpha2, ru.conv2_w, ru.conv2_s, ru.conv2_b); -// cur = ggml_snake(ctx0, cur, ru.alpha2); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = apply_conv1d(cur, ru.conv2_w, ru.conv2_s, ru.conv2_b, 1, 0); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); - -// // residual -// cur = ggml_add(ctx0, cur, inpL); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_add, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } -// } -// LLAMA_LOG_DEBUG("%s: Layer %u: Finished, cur = %p\n", __func__, il, cur); -// } - -// int64_t target_samples = 24000; // TODO: magic number -// LLAMA_LOG_DEBUG("%s: Trimming output, cur = %p, target_samples = %ld, cur->ne[0] = %ld\n", -// __func__, cur, target_samples, cur ? cur->ne[0] : -1); -// if (cur->ne[0] > target_samples) { -// cur = ggml_get_rows(ctx0, cur, ggml_new_i32(ctx0, target_samples)); -// LLAMA_LOG_DEBUG("%s: After ggml_get_rows, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } - -// LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); -// cb(cur, "result_embd", -1); -// res->t_embd = cur; - -// LLAMA_LOG_DEBUG("%s: Building forward graph, cur = %p\n", __func__, cur); -// ggml_build_forward_expand(gf, cur); -// LLAMA_LOG_DEBUG("%s: Graph build completed\n", __func__); -// } - -// // TODO: move these somewhere else -// private: -// // Helper to log tensor contents -// void log_tensor(const char * name, ggml_tensor * tensor) { -// if (!tensor) { -// LLAMA_LOG_INFO("%s: %s is null\n", __func__, name); -// return; -// } -// LLAMA_LOG_DEBUG("%s: %s shape = [%ld, %ld, %ld, %ld], first 20 elements = ", -// __func__, name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); -// int n_elements = ggml_nelements(tensor); -// float * data = (float *)tensor->data; -// for (int i = 0; i < std::min(20, n_elements); ++i) { -// LLAMA_LOG_DEBUG("%.2f ", data[i]); -// } -// if (n_elements > 20) LLAMA_LOG_DEBUG("..."); -// LLAMA_LOG_DEBUG("\n"); -// } - -// void redistribute_codes(ggml_tensor * input, ggml_tensor ** layer_1, ggml_tensor ** layer_2, ggml_tensor ** layer_3) { -// int64_t n_codes = input->ne[1]; // Assuming input is [n_embd, n_tokens, 1, 1] -// int64_t n_frames = n_codes / 7; -// if (n_codes % 7 != 0) { -// LLAMA_LOG_ERROR("%s: Input codes length %ld is not a multiple of 7\n", __func__, n_codes); -// *layer_1 = *layer_2 = *layer_3 = nullptr; -// return; -// } - -// int64_t n_layer_1 = n_frames; // 1 code per frame -// int64_t n_layer_2 = n_frames * 2; // 2 codes per frame -// int64_t n_layer_3 = n_frames * 4; // 4 codes per frame - -// // Indices for each layer -// std::vector idx_layer_1(n_layer_1); -// std::vector idx_layer_2(n_layer_2); -// std::vector idx_layer_3(n_layer_3); - -// for (int64_t i = 0; i < n_frames; ++i) { -// int64_t base_idx = i * 7; -// idx_layer_1[i] = base_idx + 0; // No offset -// idx_layer_2[i * 2] = base_idx + 1; // Offset -4096 -// idx_layer_2[i * 2 + 1] = base_idx + 4; // Offset -16384 -// idx_layer_3[i * 4] = base_idx + 2; // Offset -8192 -// idx_layer_3[i * 4 + 1] = base_idx + 3; // Offset -12288 -// idx_layer_3[i * 4 + 2] = base_idx + 5; // Offset -20480 -// idx_layer_3[i * 4 + 3] = base_idx + 6; // Offset -24576 -// } - -// // Create index tensors -// ggml_tensor * idx_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); -// ggml_tensor * idx_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); -// ggml_tensor * idx_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); - -// memcpy(idx_1->data, idx_layer_1.data(), n_layer_1 * sizeof(int32_t)); -// memcpy(idx_2->data, idx_layer_2.data(), 
n_layer_2 * sizeof(int32_t)); -// memcpy(idx_3->data, idx_layer_3.data(), n_layer_3 * sizeof(int32_t)); - -// // Extract layers using ggml_get_rows -// *layer_1 = ggml_get_rows(ctx0, input, idx_1); -// *layer_2 = ggml_get_rows(ctx0, input, idx_2); -// *layer_3 = ggml_get_rows(ctx0, input, idx_3); - -// // Apply offsets -// *layer_2 = ggml_add(ctx0, *layer_2, ggml_new_f32(ctx0, -4096.0f)); // Simplified; we'll refine offsets later -// *layer_3 = ggml_add(ctx0, *layer_3, ggml_new_f32(ctx0, -8192.0f)); // Simplified for now -// } - -// ggml_tensor * apply_conv1d(ggml_tensor * input, ggml_tensor * conv_w, ggml_tensor * conv_scale, ggml_tensor * conv_b, -// int stride, int padding) { -// ggml_tensor * w_final = normalize_weight(conv_w, conv_scale); -// ggml_tensor * cur = ggml_conv_1d_ph(ctx0, w_final, input, stride, padding); -// if (conv_b) { -// ggml_tensor* bias_reshaped = ggml_reshape_3d(ctx0, conv_b, 1, 1024, 1); -// cur = ggml_add(ctx0, cur, bias_reshaped); -// } -// return cur; -// } - -// ggml_tensor * apply_conv1d_transpose(ggml_tensor * input, ggml_tensor * up_weight, ggml_tensor * up_scale, ggml_tensor * up_bias, int stride, int padding) { -// // Normalize weights (temporary fix for up_scale shape mismatch) -// if (up_scale->ne[2] != up_weight->ne[1]) { // 1024 != 512 -// LLAMA_LOG_WARN("up_scale channels (%ld) don’t match output channels (%ld), expected behavior may vary\n", up_scale->ne[2], up_weight->ne[1]); -// // Ideally reshape up_scale to [1, 1, 512, 1], but no reshape; proceed with warning -// } -// ggml_tensor * w_final = normalize_weight(up_weight, up_scale); -// LLAMA_LOG_INFO("After normalize weight: w_final shape = [%ld, %ld, %ld, %ld]\n", -// w_final->ne[0], w_final->ne[1], w_final->ne[2], w_final->ne[3]); - -// ggml_tensor * cur = ggml_conv_transpose_1d(ctx0, w_final, input, stride, 0, 1); -// LLAMA_LOG_INFO("After ggml_conv_transpose_1d = [%ld, %ld, %ld, %ld]\n", -// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// if (up_bias) { -// // up_bias is [512, 1, 1, 1]; need [4104, 512, 1, 1] for ggml_add -// LLAMA_LOG_INFO("entering up_bias block. Before ggml_repeat, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// LLAMA_LOG_INFO("Before ggml_repeat, up_bias shape = [%ld, %ld, %ld, %ld]\n", up_bias->ne[0], up_bias->ne[1], up_bias->ne[2], up_bias->ne[3]); -// ggml_tensor * bias_repeated = ggml_repeat(ctx0, up_bias, cur); -// LLAMA_LOG_DEBUG("Repeated up_bias to shape = [%ld, %ld, %ld, %ld]\n", -// bias_repeated->ne[0], bias_repeated->ne[1], bias_repeated->ne[2], bias_repeated->ne[3]); -// cur = ggml_add(ctx0, cur, bias_repeated); -// LLAMA_LOG_DEBUG("After bias add: cur shape = [%ld, %ld, %ld, %ld]\n", -// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// } -// return cur; -// } - -// // w_final = scale * (w / || w ||) -// ggml_tensor * normalize_weight(ggml_tensor * w, ggml_tensor * scale) { -// ggml_tensor * norm = ggml_norm(ctx0, w, 1e-5f); // 1e-8f ? 
-// ggml_tensor * w_normalized = ggml_div(ctx0, w, norm); -// ggml_tensor * w_final = ggml_mul(ctx0, w_normalized, scale); -// return w_final; -// } -// }; - // TODO: Placeholder struct llm_build_snac_dec : public llm_graph_context { llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * emb_layer_1, * emb_layer_2, * emb_layer_3; + build_codebook_embd(model, &emb_layer_1, &emb_layer_2, &emb_layer_3); + + if (emb_layer_1 == nullptr || emb_layer_2 == nullptr || emb_layer_3 == nullptr) { + // graph build is called with garbage ubatch codes during model init + // in this case, bypass normal graph construction and return a dummy + LLAMA_LOG_INFO("build_codebook_inputs returned null, using dummy tensor\n"); + cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, ubatch.n_tokens > 0 ? ubatch.n_tokens : 64, 1, 1); + ggml_set_input(cur); + } else { + // Projections + cur = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[0], 8, 768), emb_layer_1); + cur = ggml_reshape_4d(ctx0, cur, 768, emb_layer_1->ne[1], 1, 1); + ggml_tensor * scale_1 = ggml_reshape_4d(ctx0, model.codebook_proj_s[0], 768, 1, 1, 1); + cur = ggml_mul(ctx0, cur, scale_1); + ggml_tensor * bias_1 = ggml_reshape_4d(ctx0, model.codebook_proj_b[0], 768, 1, 1, 1); // Fix here + cur = ggml_add(ctx0, cur, bias_1); + + ggml_tensor * proj_2 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[1], 8, 768), emb_layer_2); + proj_2 = ggml_reshape_4d(ctx0, proj_2, 768, emb_layer_2->ne[1], 1, 1); + ggml_tensor * scale_2 = ggml_reshape_4d(ctx0, model.codebook_proj_s[1], 768, 1, 1, 1); + proj_2 = ggml_mul(ctx0, proj_2, scale_2); + ggml_tensor * bias_2 = ggml_reshape_4d(ctx0, model.codebook_proj_b[1], 768, 1, 1, 1); + proj_2 = ggml_add(ctx0, proj_2, bias_2); + + ggml_tensor * proj_3 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[2], 8, 768), emb_layer_3); + proj_3 = ggml_reshape_4d(ctx0, proj_3, 768, emb_layer_3->ne[1], 1, 1); + ggml_tensor * scale_3 = ggml_reshape_4d(ctx0, model.codebook_proj_s[2], 768, 1, 1, 1); + proj_3 = ggml_mul(ctx0, proj_3, scale_3); + ggml_tensor * bias_3 = ggml_reshape_4d(ctx0, model.codebook_proj_b[2], 768, 1, 1, 1); + proj_3 = ggml_add(ctx0, proj_3, bias_3); + + cur = ggml_concat(ctx0, cur, proj_2, 1); + cur = ggml_concat(ctx0, cur, proj_3, 1); + + for (int j = 1; j <= hparams.n_layer; ++j) { + const auto & layer = model.layers[j]; + const int64_t n_in = hparams.n_channels[j-1]; + const int64_t n_out = (j < 7) ? 
hparams.n_channels[j] : hparams.n_channels[j-1]; + + if (j == 1) { + int64_t seq_len = cur->ne[1]; + cur = ggml_reshape_2d(ctx0, cur, 768, seq_len); // cur starts F32 (type 0) from projections + ggml_tensor * w = ggml_reshape_2d(ctx0, layer.conv_w, 768, 1024); // F16 (type 1) + ggml_tensor * s = ggml_cpy(ctx0, layer.conv_s, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, 1, n_out)); // Cast F32 -> F16 + w = ggml_mul(ctx0, w, s); + cur = ggml_mul_mat(ctx0, w, cur); + cur = ggml_reshape_4d(ctx0, cur, seq_len, 1024, 1, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + // Residual Units + else if (j >= 2 && j <= 5) { + ggml_tensor * alpha = layer.decoder_blocks[0].alpha; + cur = ggml_snake(ctx0, cur, alpha); + + ggml_tensor * w = layer.decoder_blocks[1].up_weight; + ggml_tensor * s = ggml_cpy(ctx0, layer.decoder_blocks[1].up_scale, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_in, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_transpose_1d(ctx0, w, cur, hparams.upsample_rates[j-2], 0, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.decoder_blocks[1].up_bias, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + + ggml_tensor * noise_w = layer.decoder_blocks[2].noise_w; + ggml_tensor * noise_s = ggml_cpy(ctx0, layer.decoder_blocks[2].noise_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + noise_w = ggml_mul(ctx0, noise_w, noise_s); + cur = ggml_conv_1d(ctx0, noise_w, cur, 1, 0, 1); + + for (int r = 0; r < 3; ++r) { + int bid = 3 + r; + ggml_tensor * w1 = layer.decoder_blocks[bid].res_unit.conv1_w; + ggml_tensor * s1 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv1_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w1 = ggml_mul(ctx0, w1, s1); + cur = ggml_conv_1d_dw(ctx0, w1, cur, 1, 3, 1); + ggml_tensor * b1 = ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv1_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b1); + + ggml_tensor * w2 = layer.decoder_blocks[bid].res_unit.conv2_w; + ggml_tensor * s2 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv2_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w2 = ggml_mul(ctx0, w2, s2); + cur = ggml_conv_1d(ctx0, w2, cur, 1, 0, 1); + ggml_tensor * b2 = ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv2_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b2); + } + } + else if (j == 6) { + ggml_tensor * alpha = layer.alpha; + cur = ggml_snake(ctx0, cur, alpha); + } + else if (j == 7) { + ggml_tensor * w = layer.conv_w; + ggml_tensor * s = layer.conv_s; + + s = ggml_reshape_4d(ctx0, s, 1, 1, 1, 1); + s = ggml_cpy(ctx0, s, ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, 1, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_1d(ctx0, w, cur, 1, 3, 1); + + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + } - // TODO: Remove - LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); - for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { - LLAMA_LOG_INFO("%d ", ubatch.token[i]); } - LLAMA_LOG("\n"); - ggml_tensor * cur; - // TODO: Hack. 
Implement codebook lookups and out_proj
-        cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1);
-        ggml_set_input(cur);
-        // end hack
+        cur = ggml_cpy(ctx0, cur, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]));
-        LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur);
         cb(cur, "result_embd", -1);
         res->t_embd = cur;
         ggml_build_forward_expand(gf, cur);
     }
+private:
+    // TODO: SNAC expects a multilayered input from 3 different embedding matrices
+    void build_codebook_embd(const llama_model & model,
+                             ggml_tensor ** emb_layer_1,
+                             ggml_tensor ** emb_layer_2,
+                             ggml_tensor ** emb_layer_3) {
+
+        *emb_layer_1 = nullptr;
+        *emb_layer_2 = nullptr;
+        *emb_layer_3 = nullptr;
+
+        bool is_initialized = (ubatch.token != nullptr && ubatch.n_tokens > 0);
+        if (is_initialized) {
+            for (int i = 0; i < ubatch.n_tokens; ++i) {
+                if (ubatch.token[i] < 0 || ubatch.token[i] >= 4096) {
+                    is_initialized = false;
+                    break;
+                }
+            }
+        }
+
+        if (!is_initialized) {
+            return;
+        }
+
+        int32_t n_tokens = ubatch.n_tokens;
+        int32_t n_frames = n_tokens / 7;
+        if (n_tokens % 7 != 0) {
+            LLAMA_LOG_INFO("build_codebook_embd: n_tokens (%d) not a multiple of 7, truncating\n", n_tokens);
+            n_frames = n_tokens / 7;
+        }
+
+        // TODO: read from vq_strides
+        int32_t n_layer_1 = n_frames;
+        int32_t n_layer_2 = n_frames * 2;
+        int32_t n_layer_3 = n_frames * 4;
+
+        LLAMA_LOG_INFO("build_codebook_embd: n_frames = %d, n_layer_1 = %d, n_layer_2 = %d, n_layer_3 = %d\n",
+                       n_frames, n_layer_1, n_layer_2, n_layer_3);
+
+        std::vector<int32_t> idx_1_data(n_layer_1);
+        std::vector<int32_t> idx_2_data(n_layer_2);
+        std::vector<int32_t> idx_3_data(n_layer_3);
+
+        // map codes to respective codebook
+        for (int32_t i = 0; i < n_frames; ++i) {
+            int32_t base_idx = i * 7;
+            idx_1_data[i] = ubatch.token[base_idx + 0];
+            idx_2_data[i * 2] = ubatch.token[base_idx + 1];
+            idx_2_data[i * 2 + 1] = ubatch.token[base_idx + 4];
+            idx_3_data[i * 4] = ubatch.token[base_idx + 2];
+            idx_3_data[i * 4 + 1] = ubatch.token[base_idx + 3];
+            idx_3_data[i * 4 + 2] = ubatch.token[base_idx + 5];
+            idx_3_data[i * 4 + 3] = ubatch.token[base_idx + 6];
+        }
+
+        // Tensors used for codebook lookups
+        ggml_tensor * idx_layer_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1);
+        ggml_tensor * idx_layer_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2);
+        ggml_tensor * idx_layer_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3);
+
+        if (!idx_layer_1 || !idx_layer_2 || !idx_layer_3) {
+            LLAMA_LOG_INFO("build_codebook_embd: Failed to allocate index tensors\n");
+            return;
+        }
+
+        // ggml is lazy, so explicitly create buffers for codes to be placed in idx_layer_N
+        ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type();
+        if (!cpu_buft) {
+            LLAMA_LOG_ERROR("build_codebook_embd: Failed to get CPU buffer type\n");
+            return;
+        }
+
+        ggml_backend_buffer_t buffer_1 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_1 * sizeof(int32_t));
+        ggml_backend_buffer_t buffer_2 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_2 * sizeof(int32_t));
+        ggml_backend_buffer_t buffer_3 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_3 * sizeof(int32_t));
+
+        if (!buffer_1 || !buffer_2 || !buffer_3) {
+            LLAMA_LOG_ERROR("build_codebook_embd: Failed to allocate backend buffers\n");
+            if (buffer_1) ggml_backend_buffer_free(buffer_1);
+            if (buffer_2) ggml_backend_buffer_free(buffer_2);
+            if (buffer_3) ggml_backend_buffer_free(buffer_3);
+            return;
+        }
+
+        // move codes to idx_layer_N
+        idx_layer_1->buffer = buffer_1;
+        idx_layer_2->buffer = buffer_2;
+        idx_layer_3->buffer = buffer_3;
+
+        idx_layer_1->data = ggml_backend_buffer_get_base(buffer_1);
+        idx_layer_2->data = ggml_backend_buffer_get_base(buffer_2);
+        idx_layer_3->data = ggml_backend_buffer_get_base(buffer_3);
+
+        ggml_backend_tensor_set(idx_layer_1, idx_1_data.data(), 0, n_layer_1 * sizeof(int32_t));
+        ggml_backend_tensor_set(idx_layer_2, idx_2_data.data(), 0, n_layer_2 * sizeof(int32_t));
+        ggml_backend_tensor_set(idx_layer_3, idx_3_data.data(), 0, n_layer_3 * sizeof(int32_t));
+
+        *emb_layer_1 = ggml_get_rows(ctx0, model.codebook[0], idx_layer_1);
+        *emb_layer_2 = ggml_get_rows(ctx0, model.codebook[1], idx_layer_2);
+        *emb_layer_3 = ggml_get_rows(ctx0, model.codebook[2], idx_layer_3);
+    }
 };
 
 llama_memory_i * llama_model::create_memory() const {
diff --git a/src/llama-model.h b/src/llama-model.h
index 5e636b0b3b3f3..e75bcf1ed8887 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -156,7 +156,7 @@ struct llama_layer_snac_dec_block {
         struct ggml_tensor * conv2_w = nullptr;
         struct ggml_tensor * conv2_s = nullptr;
         struct ggml_tensor * conv2_b = nullptr;
-    } res_units[3];
+    } res_unit;
 };
 
 struct llama_layer {
@@ -328,7 +328,7 @@ struct llama_layer {
     struct llama_layer_convnext convnext;
 
     struct ggml_tensor * conv_w = nullptr;
-    struct ggml_tensor * conv_s = nullptr; 
+    struct ggml_tensor * conv_s = nullptr;
     struct ggml_tensor * conv_b = nullptr;
 
     struct ggml_tensor * alpha = nullptr;