diff --git a/include/whisper.h b/include/whisper.h
index f4cc6bf7abd..dfed3dd431e 100644
--- a/include/whisper.h
+++ b/include/whisper.h
@@ -156,6 +156,17 @@ extern "C" {
         size_t (*read)(void * ctx, void * output, size_t read_size);
         bool   (*eof)(void * ctx);
         void   (*close)(void * ctx);
+        // Optional: skip forward by offset bytes.
+        // If NULL, skipping is not supported.
+        // Returns true on success, false on failure.
+        bool (*skip)(void * ctx, size_t offset);
+        // Optional: seek to absolute position in the file.
+        // If NULL, absolute seeking is not supported.
+        // Returns true on success, false on failure.
+        bool (*seek)(void * ctx, size_t offset);
+        // Optional: get current position in the file.
+        // If NULL, position tracking is not supported.
+        size_t (*tell)(void * ctx);
     } whisper_model_loader;

     // grammar element type
diff --git a/src/whisper.cpp b/src/whisper.cpp
index 082e7619e07..392bb9490db 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -1684,6 +1684,71 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;

+    // Pre-scan tensor metadata from file to determine actual types
+    // This allows us to allocate device memory with the correct sizes
+    struct tensor_meta {
+        ggml_type type;
+        int32_t   ne[4];
+    };
+    std::map<std::string, tensor_meta> tensor_type_map;
+    size_t tensor_start_offset = 0; // file offset where tensor section begins
+
+    // If loader supports skip, seek, and tell, scan tensor metadata first (without loading data)
+    if (loader->skip && loader->seek && loader->tell) {
+        // Remember the current position (start of tensors section)
+        tensor_start_offset = loader->tell(loader->context);
+
+        while (true) {
+            int32_t n_dims;
+
+            read_safe(loader, n_dims);
+
+            // Check for EOF after reading the first field
+            if (loader->eof(loader->context)) {
+                break;
+            }
+
+            int32_t length;
+            int32_t ttype;
+
+            read_safe(loader, length);
+            read_safe(loader, ttype);
+
+            tensor_meta meta;
+            meta.type  = ggml_type(ttype);
+            meta.ne[0] = 1;
+            meta.ne[1] = 1;
+            meta.ne[2] = 1;
+            meta.ne[3] = 1;
+
+            int32_t nelements = 1;
+            for (int i = 0; i < n_dims; ++i) {
+                read_safe(loader, meta.ne[i]);
+                nelements *= meta.ne[i];
+            }
+
+            std::string name;
+            std::vector<char> tmp(length);
+            loader->read(loader->context, &tmp[0], tmp.size());
+            name.assign(&tmp[0], tmp.size());
+
+            // Calculate tensor data size and skip it (without loading into memory)
+            const size_t tensor_data_size = ggml_row_size(meta.type, meta.ne[0]) * (nelements / meta.ne[0]);
+            if (!loader->skip(loader->context, tensor_data_size)) {
+                WHISPER_LOG_ERROR("%s: failed to skip tensor data for '%s'\n", __func__, name.c_str());
+                return false;
+            }
+
+            tensor_type_map[name] = meta;
+        }
+
+        // Seek back to the start of the tensors section for the actual data loading later
+        if (!loader->seek(loader->context, tensor_start_offset)) {
+            WHISPER_LOG_ERROR("%s: failed to seek back to tensor data\n", __func__);
+            return false;
+        }
+    }
+
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;

     auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
@@ -1712,6 +1777,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     buft_list_t buft_list = make_buft_list(wctx.params);

     auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
+        // Get the tensor name
+        std::string tensor_name = format(ASR_TENSOR_NAMES.at(system).at(type), layer);
+
+        // If we pre-scanned tensor types, update the meta tensor to use the actual type from the file
+        auto it = tensor_type_map.find(tensor_name);
+        if (it != tensor_type_map.end()) {
+            const tensor_meta & file_meta = it->second;
+            if (meta->type != file_meta.type) {
+                // Update meta tensor type to match the file
+                meta->type = file_meta.type;
+                // Update strides based on new type
+                meta->nb[0] = ggml_type_size(meta->type);
+                meta->nb[1] = meta->nb[0] * (meta->ne[0] / ggml_blck_size(meta->type));
+                for (int i = 2; i < GGML_MAX_DIMS; i++) {
+                    meta->nb[i] = meta->nb[i-1] * meta->ne[i-1];
+                }
+            }
+        }
+
         ggml_op op = ASR_TENSOR_INFO.at(type);
         ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
         if (!buft) {
@@ -1721,7 +1805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         ggml_context * ctx = get_ctx(buft);
         ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);

-        model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
+        model.tensors[tensor_name] = tensor;

         return tensor;
     };
@@ -1892,14 +1976,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             name.assign(&tmp[0], tmp.size());

             if (model.tensors.find(name) == model.tensors.end()) {
-                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
                 return false;
             }

-            auto tensor = model.tensors[name.data()];
+            auto tensor = model.tensors[name];

             if (ggml_nelements(tensor) != nelements) {
-                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
                 WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
                         __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
                 return false;
@@ -1907,7 +1991,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con

             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
                 WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                        __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
                 return false;
             }
@@ -1915,18 +1999,26 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             const size_t file_tensor_size     = ggml_row_size(ggml_type(ttype), ne[0]) * (nelements / ne[0]);
             const size_t expected_tensor_size = ggml_nbytes(tensor);

-            // For mixed precision models, the tensor type in file may differ from the type
-            // the tensor was created with. We need to handle this carefully.
+            // If we pre-scanned types, the tensor type should already match
+            // Otherwise (loader doesn't support seek), we need to handle type mismatch here
             if (tensor->type != ggml_type(ttype)) {
-                // Mixed precision: tensor created with one type, file has another
-                // We need to update the tensor's type to match the file
+                // Type mismatch - this happens when the loader doesn't support seek
+                // or when the tensor wasn't found during the pre-scan
+                if (!tensor_type_map.empty()) {
+                    // We pre-scanned but types still don't match - this is unexpected
+                    WHISPER_LOG_ERROR("%s: tensor '%s' type mismatch after pre-scan: expected %s, file has %s\n",
+                            __func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
+                    return false;
+                }
+
+                // Loader doesn't support seek - handle type mismatch at runtime (legacy path)
                 WHISPER_LOG_DEBUG("%s: tensor '%s' type mismatch (expected %s, file has %s)\n",
-                        __func__, name.data(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
+                        __func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));

                 // Check if the allocated buffer is large enough for the file's data
                 if (file_tensor_size > expected_tensor_size) {
                     WHISPER_LOG_ERROR("%s: tensor '%s' buffer too small: allocated %zu bytes for %s, but file needs %zu bytes for %s\n",
-                            __func__, name.data(), expected_tensor_size, ggml_type_name(tensor->type),
+                            __func__, name.c_str(), expected_tensor_size, ggml_type_name(tensor->type),
                             file_tensor_size, ggml_type_name(ggml_type(ttype)));
                     return false;
                 }
@@ -1941,10 +2033,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     tensor->nb[i] = tensor->nb[i-1] * tensor->ne[i-1];
                 }
             } else {
-                // Normal case: types match, verify size
+                // Types match, verify size
                 if (file_tensor_size != expected_tensor_size) {
                     WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                            __func__, name.data(), expected_tensor_size, file_tensor_size);
+                            __func__, name.c_str(), expected_tensor_size, file_tensor_size);
                     return false;
                 }
             }
@@ -3689,6 +3781,25 @@ struct whisper_context * whisper_init_from_file_with_params_no_state(const char
         fin->close();
     };

+    loader.skip = [](void * ctx, size_t offset) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->seekg(offset, std::ios::cur);
+        return fin->good();
+    };
+
+    loader.seek = [](void * ctx, size_t offset) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->clear(); // clear any error flags
+        fin->seekg(offset, std::ios::beg);
+        return fin->good();
+    };
+
+    loader.tell = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        auto pos = fin->tellg();
+        return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
+    };
+
     auto ctx = whisper_init_with_params_no_state(&loader, params);

     if (ctx) {
@@ -3732,6 +3843,29 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
     loader.close = [](void * /*ctx*/) { };

+    loader.skip = [](void * ctx, size_t offset) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+        if (buf->current_offset + offset > buf->size) {
+            return false;
+        }
+        buf->current_offset += offset;
+        return true;
+    };
+
+    loader.seek = [](void * ctx, size_t offset) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+        if (offset > buf->size) {
+            return false;
+        }
+        buf->current_offset = offset;
+        return true;
+    };
+
+    loader.tell = [](void * ctx) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+        return buf->current_offset;
+    };
+
     return whisper_init_with_params_no_state(&loader, params);
 }
@@ -4782,6 +4916,25 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
         fin->close();
     };

+    loader.skip = [](void * ctx, size_t offset) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->seekg(offset, std::ios::cur);
+        return fin->good();
+    };
+
+    loader.seek = [](void * ctx, size_t offset) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->clear();
+        fin->seekg(offset, std::ios::beg);
+        return fin->good();
+    };
+
+    loader.tell = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        auto pos = fin->tellg();
+        return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
+    };
+
     auto ctx = whisper_vad_init_with_params(&loader, params);
     if (!ctx) {
         whisper_vad_free(ctx);
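
For reference, here is a minimal sketch of how a user-supplied loader could adopt the new optional callbacks from the caller's side. It is not part of the patch: it assumes a plain `FILE *` as the loader context and a hypothetical model path, and it only demonstrates that setting all three of `skip`, `seek`, and `tell` enables the metadata pre-scan, while leaving them NULL keeps the legacy streaming path.

```cpp
// Sketch only: a FILE*-backed whisper_model_loader with the new optional
// skip/seek/tell callbacks. The model path is hypothetical.
#include <cstdio>
#include <cstdint>
#include "whisper.h"

int main() {
    FILE * f = fopen("ggml-base.en.bin", "rb"); // hypothetical model path
    if (!f) {
        return 1;
    }

    whisper_model_loader loader = {};
    loader.context = f;

    // Required callbacks (same as before this patch)
    loader.read = [](void * ctx, void * output, size_t read_size) {
        return fread(output, 1, read_size, (FILE *) ctx);
    };
    loader.eof = [](void * ctx) {
        return feof((FILE *) ctx) != 0;
    };
    loader.close = [](void * ctx) {
        fclose((FILE *) ctx);
    };

    // New optional callbacks: providing all three lets whisper_model_load
    // pre-scan tensor metadata; leaving them NULL uses the legacy path.
    loader.skip = [](void * ctx, size_t offset) {
        return fseek((FILE *) ctx, (long) offset, SEEK_CUR) == 0;
    };
    loader.seek = [](void * ctx, size_t offset) {
        return fseek((FILE *) ctx, (long) offset, SEEK_SET) == 0;
    };
    loader.tell = [](void * ctx) {
        long pos = ftell((FILE *) ctx);
        return pos < 0 ? SIZE_MAX : (size_t) pos;
    };

    struct whisper_context_params cparams = whisper_context_default_params();
    struct whisper_context * wctx = whisper_init_with_params_no_state(&loader, cparams);
    if (!wctx) {
        return 1;
    }

    // ... create a state and run inference as usual ...

    whisper_free(wctx);
    return 0;
}
```

Since the three callbacks are additive fields at the end of `whisper_model_loader`, existing callers that zero-initialize the struct keep working unchanged and simply take the legacy type-mismatch path at load time.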