Skip to content

Commit

Permalink
Merge pull request #17 from royshil/roy.fix_model_file_download
Browse files Browse the repository at this point in the history
Move models to config folder
  • Loading branch information
royshil authored Sep 17, 2023
2 parents cd4dbdb + 745937a commit 6979c49
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 69 deletions.
2 changes: 1 addition & 1 deletion cmake/BuildWhispercpp.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ if(WIN32)
else(LOCALVOCAL_WITH_CUDA)
# Build with OpenBLAS
set(OpenBLAS_URL "https://github.com/xianyi/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip")
set(OpenBLAS_SHA256 "8E777E406BA7030D21ADB18683D6175E4FA5ADACFBC09982C01E81245B348132")
set(OpenBLAS_SHA256 "6335128ee7117ea2dd2f5f96f76dafc17256c85992637189a2d5f6da0c608163")
ExternalProject_Add(
OpenBLAS
URL ${OpenBLAS_URL}
Expand Down
3 changes: 3 additions & 0 deletions src/model-utils/model-downloader-types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

typedef std::function<void(int download_status, const std::string &path)>
download_finished_callback_t;
79 changes: 55 additions & 24 deletions src/model-utils/model-downloader-ui.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include <obs-module.h>

#include <filesystem>

const std::string MODEL_BASE_PATH = "https://huggingface.co/ggerganov/whisper.cpp";
const std::string MODEL_PREFIX = "resolve/main/";

Expand All @@ -12,14 +14,17 @@ size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream)
return written;
}

ModelDownloader::ModelDownloader(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback_, QWidget *parent)
ModelDownloader::ModelDownloader(const std::string &model_name,
download_finished_callback_t download_finished_callback_,
QWidget *parent)
: QDialog(parent), download_finished_callback(download_finished_callback_)
{
this->setWindowTitle("Downloading model...");
this->setWindowTitle("LocalVocal: Downloading model...");
this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint);
this->setFixedSize(300, 100);
// Bring the dialog to the front
this->activateWindow();
this->raise();

this->layout = new QVBoxLayout(this);

Expand Down Expand Up @@ -59,24 +64,32 @@ ModelDownloader::ModelDownloader(
this->download_thread->start();
}

void ModelDownloader::closeEvent(QCloseEvent *e)
{
if (!this->mPrepareToClose)
e->ignore();
else
QDialog::closeEvent(e);
}

void ModelDownloader::close()
{
this->mPrepareToClose = true;

QDialog::close();
}

void ModelDownloader::update_progress(int progress)
{
this->progress_bar->setValue(progress);
}

void ModelDownloader::download_finished()
void ModelDownloader::download_finished(const std::string &path)
{
this->setWindowTitle("Download finished!");
this->progress_bar->setValue(100);
this->progress_bar->setFormat("Download finished!");
this->progress_bar->setAlignment(Qt::AlignCenter);
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #05B8CC; }");
// Add a button to close the dialog
QPushButton *close_button = new QPushButton("Close", this);
this->layout->addWidget(close_button);
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
// Call the callback
this->download_finished_callback(0);
// Call the callback with the path to the downloaded model
this->download_finished_callback(0, path);
// Close the dialog
this->close();
}

void ModelDownloader::show_error(const std::string &reason)
Expand All @@ -96,7 +109,7 @@ void ModelDownloader::show_error(const std::string &reason)
QPushButton *close_button = new QPushButton("Close", this);
this->layout->addWidget(close_button);
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
this->download_finished_callback(1);
this->download_finished_callback(1, "");
}

ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_)
Expand All @@ -106,9 +119,23 @@ ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_)

void ModelDownloadWorker::download_model()
{
std::string module_data_dir = obs_get_module_data_path(obs_current_module());
// join the directory and the filename using the platform-specific separator
std::string model_save_path = module_data_dir + "/" + this->model_name;
char *module_config_path = obs_module_get_config_path(obs_current_module(), "models");
// Check if the config folder exists
if (!std::filesystem::exists(module_config_path)) {
obs_log(LOG_WARNING, "Config folder does not exist: %s", module_config_path);
// Create the config folder
if (!std::filesystem::create_directories(module_config_path)) {
obs_log(LOG_ERROR, "Failed to create config folder: %s",
module_config_path);
emit download_error("Failed to create config folder.");
return;
}
}

char *model_save_path_str =
obs_module_get_config_path(obs_current_module(), this->model_name.c_str());
std::string model_save_path(model_save_path_str);
bfree(model_save_path_str);
obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str());

// extract filename from path in this->modle_name
Expand Down Expand Up @@ -143,11 +170,11 @@ void ModelDownloadWorker::download_model()
}
curl_easy_cleanup(curl);
fclose(fp);
emit download_finished(model_save_path);
} else {
obs_log(LOG_ERROR, "Failed to initialize curl.");
emit download_error("Failed to initialize curl.");
}
emit download_finished();
}

int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
Expand All @@ -168,9 +195,13 @@ int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, cu

ModelDownloader::~ModelDownloader()
{
this->download_thread->quit();
this->download_thread->wait();
delete this->download_thread;
if (this->download_thread != nullptr) {
if (this->download_thread->isRunning()) {
this->download_thread->quit();
this->download_thread->wait();
}
delete this->download_thread;
}
delete this->download_worker;
}

Expand Down
15 changes: 11 additions & 4 deletions src/model-utils/model-downloader-ui.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

#include <curl/curl.h>

#include "model-downloader-types.h"

class ModelDownloadWorker : public QObject {
Q_OBJECT
public:
Expand All @@ -20,7 +22,7 @@ public slots:

signals:
void download_progress(int progress);
void download_finished();
void download_finished(const std::string &path);
void download_error(const std::string &reason);

private:
Expand All @@ -33,22 +35,27 @@ class ModelDownloader : public QDialog {
Q_OBJECT
public:
ModelDownloader(const std::string &model_name,
std::function<void(int download_status)> download_finished_callback,
download_finished_callback_t download_finished_callback,
QWidget *parent = nullptr);
~ModelDownloader();

public slots:
void update_progress(int progress);
void download_finished();
void download_finished(const std::string &path);
void show_error(const std::string &reason);

protected:
void closeEvent(QCloseEvent *e) override;

private:
QVBoxLayout *layout;
QProgressBar *progress_bar;
QThread *download_thread;
ModelDownloadWorker *download_worker;
// Callback for when the download is finished
std::function<void(int download_status)> download_finished_callback;
download_finished_callback_t download_finished_callback;
bool mPrepareToClose;
void close();
};

#endif // MODEL_DOWNLOADER_UI_H
45 changes: 30 additions & 15 deletions src/model-utils/model-downloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,43 @@

#include <curl/curl.h>

bool check_if_model_exists(const std::string &model_name)
std::string find_model_file(const std::string &model_name)
{
obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str());
char *model_file_path = obs_module_file(model_name.c_str());
obs_log(LOG_INFO, "Model file path: %s", model_file_path);
const char *model_name_cstr = model_name.c_str();
obs_log(LOG_INFO, "Checking if model %s exists in data...", model_name_cstr);

char *model_file_path = obs_module_file(model_name_cstr);
if (model_file_path == nullptr) {
obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str());
return false;
obs_log(LOG_INFO, "Model %s not found in data.", model_name_cstr);
} else {
std::string model_file_path_str(model_file_path);
bfree(model_file_path);
if (!std::filesystem::exists(model_file_path_str)) {
obs_log(LOG_INFO, "Model not found in data: %s",
model_file_path_str.c_str());
} else {
obs_log(LOG_INFO, "Model found in data: %s", model_file_path_str.c_str());
return model_file_path_str;
}
}

if (!std::filesystem::exists(model_file_path)) {
obs_log(LOG_INFO, "Model %s does not exist.", model_file_path);
bfree(model_file_path);
return false;
// Check if model exists in the config folder
char *model_config_path_str =
obs_module_get_config_path(obs_current_module(), model_name_cstr);
std::string model_config_path(model_config_path_str);
bfree(model_config_path_str);
obs_log(LOG_INFO, "Model path in config: %s", model_config_path.c_str());
if (std::filesystem::exists(model_config_path)) {
obs_log(LOG_INFO, "Model exists in config folder: %s", model_config_path.c_str());
return model_config_path;
}
bfree(model_file_path);
return true;

obs_log(LOG_INFO, "Model %s not found.", model_name_cstr);
return "";
}

void download_model_with_ui_dialog(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback)
void download_model_with_ui_dialog(const std::string &model_name,
download_finished_callback_t download_finished_callback)
{
// Start the model downloader UI
ModelDownloader *model_downloader = new ModelDownloader(
Expand Down
9 changes: 5 additions & 4 deletions src/model-utils/model-downloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
#include <string>
#include <functional>

bool check_if_model_exists(const std::string &model_name);
#include "model-downloader-types.h"

std::string find_model_file(const std::string &model_name);

// Start the model downloader UI dialog with a callback for when the download is finished
void download_model_with_ui_dialog(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback);
void download_model_with_ui_dialog(const std::string &model_name,
download_finished_callback_t download_finished_callback);

#endif // MODEL_DOWNLOADER_H
2 changes: 1 addition & 1 deletion src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct transcription_filter_data {
audio_resampler_t *resampler = nullptr;

/* whisper */
std::string whisper_model_path = "models/ggml-tiny.en.bin";
std::string whisper_model_path;
struct whisper_context *whisper_context = nullptr;
whisper_full_params whisper_params;

Expand Down
29 changes: 10 additions & 19 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,15 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->whisper_model_path = new_model_path;

// check if the model exists, if not, download it
if (!check_if_model_exists(gf->whisper_model_path)) {
obs_log(LOG_ERROR, "Whisper model does not exist");
std::string model_file_found = find_model_file(gf->whisper_model_path);
if (model_file_found == "") {
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
gf->whisper_model_path, [gf](int download_status) {
gf->whisper_model_path,
[gf](int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO, "Model download complete");
gf->whisper_context = init_whisper_context(
gf->whisper_model_path);
gf->whisper_context = init_whisper_context(path);
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);
} else {
Expand All @@ -329,7 +330,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
});
} else {
// Model exists, just load it
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
gf->whisper_context = init_whisper_context(model_file_found);
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);
}
Expand Down Expand Up @@ -374,8 +375,8 @@ void transcription_filter_update(void *data, obs_data_t *s)

void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
{
struct transcription_filter_data *gf = static_cast<struct transcription_filter_data *>(
bzalloc(sizeof(struct transcription_filter_data)));
void *p = bzalloc(sizeof(struct transcription_filter_data));
struct transcription_filter_data *gf = new (p) transcription_filter_data;

// Get the number of channels for the input source
gf->channels = audio_output_get_channels(obs_get_audio());
Expand All @@ -396,12 +397,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
}

gf->context = filter;
gf->whisper_model_path = std::string(obs_data_get_string(settings, "whisper_model_path"));
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
if (gf->whisper_context == nullptr) {
obs_log(LOG_ERROR, "Failed to load whisper model");
return nullptr;
}
gf->whisper_model_path = ""; // The update function will set the model path

gf->overlap_ms = OVERLAP_SIZE_MSEC;
gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
Expand Down Expand Up @@ -433,11 +429,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
// get the settings updated on the filter data struct
transcription_filter_update(gf, settings);

obs_log(gf->log_level, "transcription_filter: start whisper thread");
// start the thread
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);

gf->active = true;

obs_log(gf->log_level, "transcription_filter: filter created.");
Expand Down
2 changes: 1 addition & 1 deletion src/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float v
struct whisper_context *init_whisper_context(const std::string &model_path)
{
obs_log(LOG_INFO, "Loading whisper model from %s", model_path.c_str());
struct whisper_context *ctx = whisper_init_from_file(obs_module_file(model_path.c_str()));
struct whisper_context *ctx = whisper_init_from_file(model_path.c_str());
if (ctx == nullptr) {
obs_log(LOG_ERROR, "Failed to load whisper model");
return nullptr;
Expand Down

0 comments on commit 6979c49

Please sign in to comment.