11 changes: 11 additions & 0 deletions include/whisper.h
@@ -156,6 +156,17 @@ extern "C" {
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
// Optional: skip forward by offset bytes.
// If NULL, skipping is not supported.
// Returns true on success, false on failure.
bool (*skip)(void * ctx, size_t offset);
// Optional: seek to absolute position in the file.
// If NULL, absolute seeking is not supported.
// Returns true on success, false on failure.
bool (*seek)(void * ctx, size_t offset);
// Optional: get current position in the file.
// If NULL, position tracking is not supported.
// Returns the current position, or SIZE_MAX on failure.
size_t (*tell)(void * ctx);
} whisper_model_loader;

// grammar element type
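To illustrate how a custom loader would adopt the new optional callbacks, here is a minimal sketch backed by C stdio; the FILE * context and the make_stdio_loader helper are assumptions for illustration, not part of this change:

#include <cstdint>
#include <cstdio>
#include "whisper.h"

// Minimal sketch: a whisper_model_loader over a FILE *. Any of the three
// new callbacks may be left NULL; the loader then falls back to the
// sequential read/eof path.
static whisper_model_loader make_stdio_loader(FILE * f) {
    whisper_model_loader loader = {};
    loader.context = f;
    loader.read = [](void * ctx, void * out, size_t n) {
        return fread(out, 1, n, (FILE *) ctx);
    };
    loader.eof   = [](void * ctx) { return feof((FILE *) ctx) != 0; };
    loader.close = [](void * ctx) { fclose((FILE *) ctx); };
    loader.skip  = [](void * ctx, size_t off) {
        return fseek((FILE *) ctx, (long) off, SEEK_CUR) == 0;
    };
    loader.seek  = [](void * ctx, size_t off) {
        return fseek((FILE *) ctx, (long) off, SEEK_SET) == 0;
    };
    loader.tell  = [](void * ctx) {
        long pos = ftell((FILE *) ctx);
        return pos < 0 ? SIZE_MAX : (size_t) pos;
    };
    return loader;
}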
179 changes: 166 additions & 13 deletions src/whisper.cpp
@@ -1684,6 +1684,71 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con

const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;

// Pre-scan tensor metadata from the file to determine the actual tensor types.
// This allows allocating device memory with the correct sizes.
struct tensor_meta {
ggml_type type;
int32_t ne[4];
};
std::map<std::string, tensor_meta> tensor_type_map;
size_t tensor_start_offset = 0; // file offset where tensor section begins

// If the loader supports skip, seek, and tell, scan the tensor metadata first (without loading any data)
if (loader->skip && loader->seek && loader->tell) {
// Remember the current position (start of tensors section)
tensor_start_offset = loader->tell(loader->context);

while (true) {
int32_t n_dims;

read_safe(loader, n_dims);

// Check for EOF after reading the first field
if (loader->eof(loader->context)) {
break;
}

int32_t length;
int32_t ttype;

read_safe(loader, length);
read_safe(loader, ttype);

tensor_meta meta;
meta.type = ggml_type(ttype);
meta.ne[0] = 1;
meta.ne[1] = 1;
meta.ne[2] = 1;
meta.ne[3] = 1;

int32_t nelements = 1;
for (int i = 0; i < n_dims; ++i) {
read_safe(loader, meta.ne[i]);
nelements *= meta.ne[i];
}

std::string name;
std::vector<char> tmp(length);
loader->read(loader->context, &tmp[0], tmp.size());
name.assign(&tmp[0], tmp.size());

// Calculate tensor data size and skip it (without loading into memory)
const size_t tensor_data_size = ggml_row_size(meta.type, meta.ne[0]) * (nelements / meta.ne[0]);
if (!loader->skip(loader->context, tensor_data_size)) {
WHISPER_LOG_ERROR("%s: failed to skip tensor data for '%s'\n", __func__, name.c_str());
return false;
}

tensor_type_map[name] = meta;
}

// Seek back to the start of tensors section for the actual data loading later
if (!loader->seek(loader->context, tensor_start_offset)) {
WHISPER_LOG_ERROR("%s: failed to seek back to tensor data\n", __func__);
return false;
}
}
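For reference, each record the pre-scan walks is laid out as: int32 n_dims, int32 name length, int32 ttype, then n_dims int32 dims, the name bytes, and finally the raw tensor payload. The payload size follows ggml's row convention, sketched below under the assumption (which ggml enforces) that ne[0] is a multiple of the type's block size:

// Sketch: payload size in bytes of one tensor record. ggml_row_size()
// accounts for block-quantized types (e.g. Q4_0 packs 32 elements per block),
// so the total is bytes-per-row times the number of rows.
static size_t tensor_payload_bytes(ggml_type type, const int32_t * ne, int32_t n_dims) {
    int64_t nelements = 1;
    for (int32_t i = 0; i < n_dims; ++i) {
        nelements *= ne[i];
    }
    return ggml_row_size(type, ne[0]) * (size_t) (nelements / ne[0]);
}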

std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
@@ -1712,6 +1777,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
buft_list_t buft_list = make_buft_list(wctx.params);

auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
// Get the tensor name
std::string tensor_name = format(ASR_TENSOR_NAMES.at(system).at(type), layer);

// If we pre-scanned tensor types, update the meta tensor to use the actual type from the file
auto it = tensor_type_map.find(tensor_name);
if (it != tensor_type_map.end()) {
const tensor_meta & file_meta = it->second;
if (meta->type != file_meta.type) {
// Update meta tensor type to match the file
meta->type = file_meta.type;
// Update strides based on new type
meta->nb[0] = ggml_type_size(meta->type);
meta->nb[1] = meta->nb[0] * (meta->ne[0] / ggml_blck_size(meta->type));
for (int i = 2; i < GGML_MAX_DIMS; i++) {
meta->nb[i] = meta->nb[i-1] * meta->ne[i-1];
}
}
}

ggml_op op = ASR_TENSOR_INFO.at(type);
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
if (!buft) {
@@ -1721,7 +1805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
ggml_context * ctx = get_ctx(buft);
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);

- model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
+ model.tensors[tensor_name] = tensor;

return tensor;
};
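To make the stride update above concrete, here are row byte sizes for a hypothetical 384-wide tensor under two types (block sizes taken from ggml's type traits and stated here as an assumption):

// Sketch: row byte sizes for ne[0] = 384 under two tensor types.
const size_t row_f16  = ggml_row_size(GGML_TYPE_F16,  384); // 384 / 1  * 2  = 768 bytes
const size_t row_q8_0 = ggml_row_size(GGML_TYPE_Q8_0, 384); // 384 / 32 * 34 = 408 bytes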
@@ -1892,41 +1976,49 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
name.assign(&tmp[0], tmp.size());

if (model.tensors.find(name) == model.tensors.end()) {
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
return false;
}

- auto tensor = model.tensors[name.data()];
+ auto tensor = model.tensors[name];

if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return false;
}

if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
- __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+ __func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}

// Calculate size based on file's tensor type
const size_t file_tensor_size = ggml_row_size(ggml_type(ttype), ne[0]) * (nelements / ne[0]);
const size_t expected_tensor_size = ggml_nbytes(tensor);

- // For mixed precision models, the tensor type in file may differ from the type
- // the tensor was created with. We need to handle this carefully.
+ // If we pre-scanned types, the tensor type should already match.
+ // Otherwise (the loader doesn't support seek), we need to handle the type mismatch here.
if (tensor->type != ggml_type(ttype)) {
- // Mixed precision: tensor created with one type, file has another
- // We need to update the tensor's type to match the file
+ // Type mismatch: this happens when the loader doesn't support seek
+ // or when the tensor wasn't found during the pre-scan
if (!tensor_type_map.empty()) {
// We pre-scanned but types still don't match - this is unexpected
WHISPER_LOG_ERROR("%s: tensor '%s' type mismatch after pre-scan: expected %s, file has %s\n",
__func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
return false;
}

// The loader doesn't support seek: handle the type mismatch at runtime (legacy path)
WHISPER_LOG_DEBUG("%s: tensor '%s' type mismatch (expected %s, file has %s)\n",
- __func__, name.data(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
+ __func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));

// Check if the allocated buffer is large enough for the file's data
if (file_tensor_size > expected_tensor_size) {
WHISPER_LOG_ERROR("%s: tensor '%s' buffer too small: allocated %zu bytes for %s, but file needs %zu bytes for %s\n",
- __func__, name.data(), expected_tensor_size, ggml_type_name(tensor->type),
+ __func__, name.c_str(), expected_tensor_size, ggml_type_name(tensor->type),
file_tensor_size, ggml_type_name(ggml_type(ttype)));
return false;
}
@@ -1941,10 +2033,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
tensor->nb[i] = tensor->nb[i-1] * tensor->ne[i-1];
}
} else {
- // Normal case: types match, verify size
+ // Types match, verify size
if (file_tensor_size != expected_tensor_size) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
- __func__, name.data(), expected_tensor_size, file_tensor_size);
+ __func__, name.c_str(), expected_tensor_size, file_tensor_size);
return false;
}
}
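The buffer-size guard above encodes a one-way compatibility rule on the legacy path: a quantized payload fits into a buffer that was allocated for a wider type, but never the reverse. A per-row check could look like this sketch (payload_fits is a hypothetical helper; both sides share the row count, so comparing one row suffices):

// Sketch: on the legacy path a tensor can only shrink in place.
// E.g. Q8_0 into an F16-sized buffer: 408 <= 768 bytes per row -> fits;
// F16 into a Q8_0-sized buffer: 768 > 408 -> rejected by the check above.
static bool payload_fits(ggml_type file_type, ggml_type alloc_type, int64_t ne0) {
    return ggml_row_size(file_type, ne0) <= ggml_row_size(alloc_type, ne0);
}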
@@ -3689,6 +3781,25 @@ struct whisper_context * whisper_init_from_file_with_params_no_state(const char
fin->close();
};

loader.skip = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->seekg(offset, std::ios::cur);
return fin->good();
};

loader.seek = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->clear(); // clear any error flags
fin->seekg(offset, std::ios::beg);
return fin->good();
};

loader.tell = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
auto pos = fin->tellg();
return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
};
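The tell lambda maps tellg()'s failure value (-1) to SIZE_MAX, since the callback returns an unsigned size_t. A caller that wants to detect that failure explicitly could check for the sentinel; a brief sketch (illustrative, not part of the change):

// Sketch: detect a failed tell via the SIZE_MAX sentinel.
size_t pos = loader.tell(loader.context);
if (pos == SIZE_MAX) {
    // stream error: position unknown, fall back to sequential loading
}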

auto ctx = whisper_init_with_params_no_state(&loader, params);

if (ctx) {
@@ -3732,6 +3843,29 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu

loader.close = [](void * /*ctx*/) { };

loader.skip = [](void * ctx, size_t offset) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
if (buf->current_offset + offset > buf->size) {
return false;
}
buf->current_offset += offset;
return true;
};

loader.seek = [](void * ctx, size_t offset) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
if (offset > buf->size) {
return false;
}
buf->current_offset = offset;
return true;
};

loader.tell = [](void * ctx) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
return buf->current_offset;
};

return whisper_init_with_params_no_state(&loader, params);
}
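With the buffer-backed callbacks in place, in-memory models go through the existing public entry point and now get the pre-scan path as well; a usage sketch (the load_from_memory wrapper and the vector source are hypothetical):

#include <cstdint>
#include <vector>
#include "whisper.h"

// Sketch: load a ggml model already held in memory; skip/seek/tell let the
// loader pre-scan tensor types without copying the buffer.
whisper_context * load_from_memory(std::vector<uint8_t> & model_data) {
    whisper_context_params cparams = whisper_context_default_params();
    return whisper_init_from_buffer_with_params(model_data.data(), model_data.size(), cparams);
}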

@@ -4782,6 +4916,25 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
fin->close();
};

loader.skip = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->seekg(offset, std::ios::cur);
return fin->good();
};

loader.seek = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->clear();
fin->seekg(offset, std::ios::beg);
return fin->good();
};

loader.tell = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
auto pos = fin->tellg();
return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
};

auto ctx = whisper_vad_init_with_params(&loader, params);
if (!ctx) {
whisper_vad_free(ctx);