examples : add sample SAM inference #74

Merged: 8 commits, Aug 18, 2023
Changes from 3 commits
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -27,3 +27,4 @@ add_subdirectory(dolly-v2)
add_subdirectory(replit)
add_subdirectory(mpt)
add_subdirectory(starcoder)
add_subdirectory(sam)
43 changes: 43 additions & 0 deletions examples/common.cpp
@@ -766,3 +766,46 @@ float similarity(const std::string & s0, const std::string & s1) {

return 1.0f - (dist / std::max(s0.size(), s1.size()));
}

bool sam_params_parse(int argc, char ** argv, sam_params & params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        if (arg == "-s" || arg == "--seed") {
            params.seed = std::stoi(argv[++i]);
        } else if (arg == "-t" || arg == "--threads") {
            params.n_threads = std::stoi(argv[++i]);
        } else if (arg == "-m" || arg == "--model") {
            params.model = argv[++i];
        } else if (arg == "-i" || arg == "--inp") {
            params.fname_inp = argv[++i];
        } else if (arg == "-o" || arg == "--out") {
            params.fname_out = argv[++i];
        } else if (arg == "-h" || arg == "--help") {
            sam_print_usage(argc, argv, params);
            exit(0);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            sam_print_usage(argc, argv, params);
            exit(1);
        }
    }

    return true;
}

void sam_print_usage(int argc, char ** argv, const sam_params & params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help            show this help message and exit\n");
    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
    fprintf(stderr, "  -m FNAME, --model FNAME\n");
    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
    fprintf(stderr, "  -i FNAME, --inp FNAME\n");
    fprintf(stderr, "                        input file (default: %s)\n", params.fname_inp.c_str());
    fprintf(stderr, "  -o FNAME, --out FNAME\n");
    fprintf(stderr, "                        output file (default: %s)\n", params.fname_out.c_str());
    fprintf(stderr, "\n");
}
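
For reference, a minimal sketch of how an example's main() might drive these helpers. This is hypothetical illustration code, not part of this diff; it assumes only the sam_params API declared in common.h below.

// hypothetical driver, for illustration only
#include "common.h"

#include <cstdio>

int main(int argc, char ** argv) {
    sam_params params;

    if (!sam_params_parse(argc, argv, params)) {
        return 1;
    }

    fprintf(stderr, "%s: seed = %d, n_threads = %d, model = %s\n",
            __func__, params.seed, params.n_threads, params.model.c_str());

    // ... load params.model, segment params.fname_inp using params.n_threads
    //     threads, and write the resulting masks to params.fname_out ...

    return 0;
}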
19 changes: 18 additions & 1 deletion examples/common.h
@@ -11,7 +11,7 @@
#define COMMON_SAMPLE_RATE 16000

//
// CLI argument parsing
// GPT CLI argument parsing
//

struct gpt_params {
@@ -157,3 +157,20 @@ bool vad_simple(

// compute similarity between two strings using Levenshtein distance
float similarity(const std::string & s0, const std::string & s1);

//
// SAM argument parsing
//

struct sam_params {
    int32_t seed      = -1; // RNG seed
    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
    std::string fname_inp = "img.jpg";
    std::string fname_out = "img.out";
};

bool sam_params_parse(int argc, char ** argv, sam_params & params);

void sam_print_usage(int argc, char ** argv, const sam_params & params);
13 changes: 13 additions & 0 deletions examples/sam/CMakeLists.txt
@@ -0,0 +1,13 @@
#
# sam

set(TEST_TARGET sam)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common)

#
# sam-quantize

#set(TEST_TARGET sam-quantize)
#add_executable(${TEST_TARGET} quantize.cpp)
#target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
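
Assuming the usual ggml build flow, the new target can then be built with something like cmake --build build --target sam (invocation assumed; the PR itself only registers the target).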
134 changes: 134 additions & 0 deletions examples/sam/convert-pth-to-ggml.py
@@ -0,0 +1,134 @@
# Convert a SAM model checkpoint to a ggml compatible file
#

import os
import sys
import code
import json
import torch
import struct
import numpy as np

if len(sys.argv) < 3:
    print("Usage: convert-pth-to-ggml.py file-model ftype\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)

# output in the same directory as the model
fname_model = sys.argv[1]
fname_out = os.path.dirname(fname_model) + "/ggml-model.bin"

# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])

if ftype < 0 or ftype > 1:
    print("Invalid ftype: " + str(ftype))
    sys.exit(1)

fname_out = fname_out.replace(".bin", "-" + ftype_str[ftype] + ".bin")

model = torch.load(fname_model, map_location="cpu")

# TODO: determine based on model data
# TODO: add decoder / prompt encoder if needed
hparams = {
    "n_enc_state": 768,
    "n_enc_layers": 12,
    "n_enc_heads": 12,
    "n_enc_out_chans": 256,

    "n_pt_embd": 4,
}

print(hparams)

for k, v in model.items():
    print(k, v.shape)

#exit()
#code.interact(local=locals())

fout = open(fname_out, "wb")

fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["n_enc_state"]))
fout.write(struct.pack("i", hparams["n_enc_layers"]))
fout.write(struct.pack("i", hparams["n_enc_heads"]))
fout.write(struct.pack("i", hparams["n_enc_out_chans"]))
fout.write(struct.pack("i", hparams["n_pt_embd"]))
fout.write(struct.pack("i", ftype))

for k, v in model.items():
    name = k
    shape = v.shape

    if name[:19] == "prompt_encoder.mask":
        continue

    print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)

    #data = tf.train.load_variable(dir_model, name).squeeze()
    #data = v.numpy().squeeze()
    data = v.numpy()
    n_dims = len(data.shape)

    # for efficiency - transpose some matrices
    # "model/h.*/attn/c_attn/w"
    # "model/h.*/attn/c_proj/w"
    # "model/h.*/mlp/c_fc/w"
    # "model/h.*/mlp/c_proj/w"
    #if name[-14:] == "/attn/c_attn/w" or \
    #   name[-14:] == "/attn/c_proj/w" or \
    #   name[-11:] == "/mlp/c_fc/w" or \
    #   name[-13:] == "/mlp/c_proj/w":
    #    print("  Transposing")
    #    data = data.transpose()

    dshape = data.shape

    # default type is fp16
    ftype_cur = 1
    if ftype == 0 or n_dims == 1 or \
       name == "image_encoder.pos_embed" or \
       name.startswith("prompt_encoder") or \
       name.startswith("mask_decoder.iou_token") or \
       name.startswith("mask_decoder.mask_tokens"):
        print("  Converting to float32")
        data = data.astype(np.float32)
        ftype_cur = 0
    else:
        print("  Converting to float16")
        data = data.astype(np.float16)

    # reshape the 1D bias into a 4D tensor so we can use ggml_repeat
    # keep it in F32 since the data is small
    if name == "image_encoder.patch_embed.proj.bias":
        data = data.reshape(1, data.shape[0], 1, 1)
        n_dims = len(data.shape)
        dshape = data.shape

        print("  New shape: ", dshape)

    # header
    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
    fout.write(name_bytes)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")