Skip to content

Commit

Permalink
change combobox logic
Browse files — browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Aug 29, 2024
1 parent 95d5db8 commit d1905bd
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 59 deletions.
34 changes: 30 additions & 4 deletions settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from PySide6.QtWidgets import (QGroupBox, QVBoxLayout, QHBoxLayout, QLabel, QComboBox, QSlider)
from PySide6.QtCore import Qt
from PySide6.QtCore import Qt, Signal
from constants import WHISPER_MODELS
import torch

from utilities import has_bfloat16_support

class SettingsGroupBox(QGroupBox):
device_changed = Signal(str)

def __init__(self, get_compute_and_platform_info_callback, parent=None):
super().__init__("Settings", parent)
self.get_compute_and_platform_info = get_compute_and_platform_info_callback
Expand All @@ -17,14 +22,14 @@ def initUI(self):
hbox1_layout.addWidget(modelLabel)

self.modelComboBox = QComboBox()
self.modelComboBox.addItems(WHISPER_MODELS.keys())
hbox1_layout.addWidget(self.modelComboBox)

computeDeviceLabel = QLabel("Device:")
hbox1_layout.addWidget(computeDeviceLabel)

self.computeDeviceComboBox = QComboBox()
hbox1_layout.addWidget(self.computeDeviceComboBox)
self.computeDeviceComboBox.currentTextChanged.connect(self.on_device_changed)

formatLabel = QLabel("Output:")
hbox1_layout.addWidget(formatLabel)
Expand Down Expand Up @@ -67,7 +72,7 @@ def initUI(self):
self.batchSizeSlider = QSlider(Qt.Horizontal)
self.batchSizeSlider.setMinimum(1)
self.batchSizeSlider.setMaximum(200)
self.batchSizeSlider.setValue(16)
self.batchSizeSlider.setValue(8)
self.batchSizeSlider.setTickPosition(QSlider.TicksBelow)
self.batchSizeSlider.setTickInterval(10)
batch_size_layout.addWidget(self.batchSizeSlider)
Expand All @@ -86,4 +91,25 @@ def update_slider_label(self, slider, label):
def populateComputeDeviceComboBox(self):
    """Fill the device combo box and select a sensible default device.

    Prefers "cuda" when it is among the available devices, otherwise falls
    back to "cpu", then refreshes the model list so only models compatible
    with the selected device are offered.
    """
    available_devices = self.get_compute_and_platform_info()
    self.computeDeviceComboBox.addItems(available_devices)
    # Removed the redundant initial findText("cpu") selection that was
    # immediately overwritten by the branch below.
    default_device = "cuda" if "cuda" in available_devices else "cpu"
    self.computeDeviceComboBox.setCurrentIndex(
        self.computeDeviceComboBox.findText(default_device))
    self.update_model_combobox()

def on_device_changed(self, device):
    """Relay a device-combo change to listeners and refresh model choices.

    Emits ``device_changed`` first so external observers see the new device,
    then rebuilds the model combo box to match it.
    """
    self.device_changed.emit(device)
    self.update_model_combobox()

def update_model_combobox(self):
    """Repopulate the model combo box with models usable on the current device.

    CPU can only run float32 models; CUDA runs float32 and float16, plus
    bfloat16 when the GPU supports it. ``has_bfloat16_support()`` is now
    queried once (hoisted out of the loop) instead of once per bfloat16
    model, and the per-model branching is collapsed into a single
    membership test against the allowed-precision set.
    """
    current_device = self.computeDeviceComboBox.currentText()
    self.modelComboBox.clear()

    # Precisions the currently selected device can execute.
    if current_device == "cpu":
        allowed_precisions = {"float32"}
    elif current_device == "cuda":
        allowed_precisions = {"float32", "float16"}
        if has_bfloat16_support():  # hoisted: query the GPU once, not per model
            allowed_precisions.add("bfloat16")
    else:
        # Unknown device: offer nothing (matches original behavior, which
        # only added items for "cpu" and "cuda").
        allowed_precisions = set()

    for model_name, model_info in WHISPER_MODELS.items():
        if model_info['precision'] in allowed_precisions:
            self.modelComboBox.addItem(model_name)
22 changes: 14 additions & 8 deletions utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@ def get_compute_and_platform_info():
return available_devices


def get_supported_quantizations(device_type):
types = ctranslate2.get_supported_compute_types(device_type)
filtered_types = [q for q in types if q != 'int16']
desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 'int8']
sorted_types = [q for q in desired_order if q in filtered_types]
return sorted_types
# def get_supported_quantizations(device_type):
# types = ctranslate2.get_supported_compute_types(device_type)
# filtered_types = [q for q in types if q != 'int16']
# desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 'int8']
# sorted_types = [q for q in desired_order if q in filtered_types]
# return sorted_types

def get_logical_core_count():
    """Return the number of logical CPU cores (hyperthreads included)."""
    logical_cores = psutil.cpu_count(logical=True)
    return logical_cores

def get_physical_core_count():
    """Return the number of physical CPU cores (hyperthreads excluded)."""
    physical_cores = psutil.cpu_count(logical=False)
    return physical_cores
def has_bfloat16_support():
    """Return True if the active CUDA device supports bfloat16 compute.

    Returns False when no CUDA device is available. Hardware bfloat16
    arithmetic is available on NVIDIA GPUs of compute capability 8.0
    (Ampere) and newer, so the previous ``>= (8, 6)`` threshold wrongly
    excluded sm_80 parts such as the A100.
    """
    if not torch.cuda.is_available():
        return False

    # (major, minor) compute capability of the current CUDA device.
    capability = torch.cuda.get_device_capability()
    return capability >= (8, 0)
84 changes: 39 additions & 45 deletions whispers2t_batch_gui.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
import logging
import os
import sys
import traceback
from pathlib import Path
from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QCheckBox, QLabel, QGroupBox, QMessageBox

import torch
from PySide6.QtCore import Qt
import torch
from utilities import get_compute_and_platform_info, get_supported_quantizations
from whispers2t_batch_transcriber import Worker
from PySide6.QtWidgets import (
QApplication,
QCheckBox,
QFileDialog,
QGroupBox,
QHBoxLayout,
QLabel,
QMessageBox,
QPushButton,
QVBoxLayout,
QWidget,
)

from constants import WHISPER_MODELS
from metrics_bar import MetricsBar
from settings import SettingsGroupBox
import logging
import traceback
from constants import WHISPER_MODELS
from utilities import has_bfloat16_support
from whispers2t_batch_transcriber import Worker

def set_cuda_paths():
try:
Expand All @@ -26,18 +39,14 @@ def set_cuda_paths():

set_cuda_paths()

def is_nvidia_gpu_available():
    """Return True when a CUDA device is present and identifies as NVIDIA."""
    # Guard clause preserves the original short-circuit: never query the
    # device name unless CUDA reports an available device.
    if not torch.cuda.is_available():
        return False
    device_name = torch.cuda.get_device_name(0)
    return "nvidia" in device_name.lower()

class MainWindow(QWidget):
def __init__(self):
super().__init__()
self.initUI()

def initUI(self):
self.setWindowTitle("chintellalaw.com - for non-commercial use")
initial_height = 400 if is_nvidia_gpu_available() else 370
self.setGeometry(100, 100, 680, initial_height)
self.setGeometry(100, 100, 680, 400)
self.setWindowFlags(self.windowFlags() | Qt.WindowStaysOnTopHint)

main_layout = QVBoxLayout()
Expand Down Expand Up @@ -69,7 +78,8 @@ def initUI(self):
fileExtensionsGroupBox.setLayout(fileExtensionsLayout)
main_layout.addWidget(fileExtensionsGroupBox)

self.settingsGroupBox = SettingsGroupBox(get_compute_and_platform_info, self)
self.settingsGroupBox = SettingsGroupBox(self.get_compute_and_platform_info, self)
self.settingsGroupBox.device_changed.connect(self.on_device_changed)
main_layout.addWidget(self.settingsGroupBox)

selectDirLayout = QHBoxLayout()
Expand Down Expand Up @@ -101,6 +111,16 @@ def closeEvent(self, event):
self.metricsBar.stop_metrics_collector()
super().closeEvent(event)

def get_compute_and_platform_info(self):
    """Return the compute devices available to the application.

    "cpu" is always offered; "cuda" is added when PyTorch reports an
    available CUDA device.
    """
    cuda_present = torch.cuda.is_available()
    return ["cpu", "cuda"] if cuda_present else ["cpu"]

def on_device_changed(self, device):
    """Hook invoked when the settings box reports a device change.

    Intentionally a no-op; extend here if the main window ever needs to
    react to device switches.
    """
    pass

def selectDirectory(self):
dirPath = QFileDialog.getExistingDirectory(self, "Select Directory")
if dirPath:
Expand All @@ -122,46 +142,20 @@ def calculate_files_to_process(self):
return total_files

def perform_checks(self):
model = self.settingsGroupBox.modelComboBox.currentText()
device = self.settingsGroupBox.computeDeviceComboBox.currentText()
batch_size = self.settingsGroupBox.batchSizeSlider.value()
beam_size = self.settingsGroupBox.beamSizeSlider.value()

# Check 1: CPU and non-float32 model
if "float32" not in model.lower() and device.lower() == "cpu":
QMessageBox.warning(self, "Invalid Configuration",
"CPU only supports Float 32 computation. Please select a different Whisper model.")
return False

# Check 2: CPU with high batch size
if device.lower() == "cpu" and batch_size > 16:
reply = QMessageBox.warning(self, "Performance Warning",
"When using CPU it is generally recommended to use a batch size of no more than 16 "
"otherwise compute time will actually be worse.\n\n"
"Moreover, if you select a Beam Size greater than one, you should reduce the Batch Size accordingly.\n\n"
"For example:\n"
"- If you select a Beam Size of 2 (double the default value of 1) you would reduce the Batch Size (default value 16) by half.\n"
"- If Beam Size is set to 3 you should reduce the Batch Size to 1/3 of the default level, and so on.\n\nClick OK to proceed.",
# Check: CPU with high batch size
if device.lower() == "cpu" and batch_size > 8:
reply = QMessageBox.warning(self, "Warning",
"When using CPU it is generally recommended to use a batch size of no more than 8 "
"otherwise compute could actually be worse. Use at your own risk.",
QMessageBox.Ok | QMessageBox.Cancel,
QMessageBox.Cancel)
if reply == QMessageBox.Cancel:
return False

# Check 3: GPU compatibility
# Only perform this check if the device is not CPU
if device.lower() != "cpu":
supported_quantizations = get_supported_quantizations(device)
if "float16" in model.lower() and "float16" not in supported_quantizations:
QMessageBox.warning(self, "Incompatible Configuration",
"Your GPU does not support the selected floating point value (float16). "
"Please make another selection.")
return False
if "bfloat16" in model.lower() and "bfloat16" not in supported_quantizations:
QMessageBox.warning(self, "Incompatible Configuration",
"Your GPU does not support the selected floating point value (bfloat16). "
"Please make another selection.")
return False

return True # All checks passed

def processFiles(self):
Expand Down Expand Up @@ -227,4 +221,4 @@ def workerFinished(self, message):
app.setStyle("Fusion")
mainWindow = MainWindow()
mainWindow.show()
sys.exit(app.exec())
sys.exit(app.exec())
4 changes: 2 additions & 2 deletions whispers2t_batch_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import torch

from constants import WHISPER_MODELS
from utilities import get_physical_core_count
from utilities import get_logical_core_count

CPU_THREADS = max(4, get_physical_core_count() - 1)
CPU_THREADS = max(4, get_logical_core_count() - 8)

class Worker(QThread):
finished = Signal(str)
Expand Down

0 comments on commit d1905bd

Please sign in to comment.