adding ggml_to_pt script #1042

Merged: 3 commits, Jun 25, 2023
109 changes: 109 additions & 0 deletions models/ggml_to_pt.py
@@ -0,0 +1,109 @@
import struct
import torch
import numpy as np
from collections import OrderedDict
from pathlib import Path
import sys

if len(sys.argv) < 3:
    print("Usage: ggml_to_pt.py model.bin dir-output\n")
    sys.exit(1)

fname_inp = Path(sys.argv[1])
dir_out = Path(sys.argv[2])
fname_out = dir_out / "torch-model.pt"
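# torch.save below does not create missing directories, so make sure the
# output directory exists up front
dir_out.mkdir(parents=True, exist_ok=True)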



# Open the ggml file
with open(fname_inp, "rb") as f:
    # Read magic number and hyperparameters (12 int32 values, 48 bytes)
    magic_number, n_vocab, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, n_text_ctx, n_text_state, n_text_head, n_text_layer, n_mels, use_f16 = struct.unpack("12i", f.read(48))
    print(f"Magic number: {magic_number}")
    print(f"Vocab size: {n_vocab}")
    print(f"Audio context size: {n_audio_ctx}")
    print(f"Audio state size: {n_audio_state}")
    print(f"Audio head size: {n_audio_head}")
    print(f"Audio layer size: {n_audio_layer}")
    print(f"Text context size: {n_text_ctx}")
    print(f"Text state size: {n_text_state}")
    print(f"Text head size: {n_text_head}")
    print(f"Text layer size: {n_text_layer}")
    print(f"Mel size: {n_mels}")
    print(f"Use f16: {use_f16}")
    # Read the mel filterbank shape: two int32 values
    filters_shape_0 = struct.unpack("i", f.read(4))[0]
    print(f"Filters shape 0: {filters_shape_0}")
    filters_shape_1 = struct.unpack("i", f.read(4))[0]
    print(f"Filters shape 1: {filters_shape_1}")

    # Read the mel filter values (filters_shape_0 x filters_shape_1 float32);
    # they are parsed only to advance the stream, the PyTorch checkpoint
    # does not need them
    mel_filters = np.zeros((filters_shape_0, filters_shape_1))
    for i in range(filters_shape_0):
        for j in range(filters_shape_1):
            mel_filters[i][j] = struct.unpack("f", f.read(4))[0]

    # Read tokenizer tokens: an int32 count, then length-prefixed byte
    # strings (collected here but not needed for the PyTorch checkpoint)
    bytes_data = f.read(4)
    num_tokens = struct.unpack("i", bytes_data)[0]
    tokens = {}

    for _ in range(num_tokens):
        token_len = struct.unpack("i", f.read(4))[0]
        token = f.read(token_len)
        tokens[token] = {}

    # Read model variables
    model_state_dict = OrderedDict()
    while True:
        try:
            n_dims, name_length, ftype = struct.unpack("iii", f.read(12))
        except struct.error:
            break  # End of file
        dims = [struct.unpack("i", f.read(4))[0] for _ in range(n_dims)]
        # ggml writes dimensions innermost-first; reverse them to match the
        # row-major layout PyTorch expects
        dims = dims[::-1]
        name = f.read(name_length).decode("utf-8")
        if ftype == 1:  # f16
            data = np.fromfile(f, dtype=np.float16, count=np.prod(dims)).reshape(dims)
        else:  # f32
            data = np.fromfile(f, dtype=np.float32, count=np.prod(dims)).reshape(dims)

        # ggml stores the conv biases as [n, 1] (after the reverse above);
        # PyTorch expects a 1-D tensor of shape [n]
        if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
            data = data[:, 0]

        model_state_dict[name] = torch.from_numpy(data)

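# Quick sanity check that the tensor stream was read to the end
print(f"Read {len(model_state_dict)} tensors from {fname_inp}")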
# Now you have the model's state_dict stored in model_state_dict.
# Load it into a Whisper model with the same architecture, built from the
# hyperparameters read out of the ggml header above.
from whisper import Whisper, ModelDimensions

dims = ModelDimensions(
    n_mels=n_mels,
    n_audio_ctx=n_audio_ctx,
    n_audio_state=n_audio_state,
    n_audio_head=n_audio_head,
    n_audio_layer=n_audio_layer,
    n_text_ctx=n_text_ctx,
    n_text_state=n_text_state,
    n_text_head=n_text_head,
    n_text_layer=n_text_layer,
    n_vocab=n_vocab,
)
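# Whisper(dims) allocates float32 parameters; load_state_dict copies the
# converted tensors into them, upcasting any f16 data along the way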
model = Whisper(dims)
model.load_state_dict(model_state_dict)

# Save the model in PyTorch format
torch.save(model.state_dict(), fname_out)
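
A quick way to check the result (paths here are illustrative): the script saves a bare state_dict rather than the checkpoint dict that whisper.load_model expects (one carrying "dims" and "model_state_dict" keys), so reload it with torch.load directly:

    # Example invocation of the script above:
    #   python models/ggml_to_pt.py ggml-base.en.bin out/
    import torch

    state_dict = torch.load("out/torch-model.pt", map_location="cpu")
    for name, tensor in list(state_dict.items())[:5]:
        print(name, tuple(tensor.shape), tensor.dtype)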