Skip to content

Commit

Permalink
Tests for offline mode, and model installation checks
Browse files Browse the repository at this point in the history
  • Loading branch information
bmurray committed Nov 1, 2024
1 parent 7871808 commit c922681
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Configurations.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ open class WhisperKitConfig {

public init(model: String? = nil,
downloadBase: URL? = nil,
modelRepo: String = "argmaxinc/whisperkit-coreml",
modelRepo: String? = nil,
modelFolder: String? = nil,
tokenizerFolder: URL? = nil,
computeOptions: ModelComputeOptions? = nil,
Expand Down
20 changes: 12 additions & 8 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ open class WhisperKit {
try await setupModels(
model: config.model,
downloadBase: config.downloadBase,
modelRepo: config.modelRepo ?? "argmaxinc/whisperkit-coreml",
modelRepo: config.modelRepo,
modelFolder: config.modelFolder,
download: config.download
)
Expand All @@ -79,7 +79,7 @@ open class WhisperKit {
public convenience init(
model: String? = nil,
downloadBase: URL? = nil,
modelRepo: String = "argmaxinc/whisperkit-coreml",
modelRepo: String? = nil,
modelFolder: String? = nil,
tokenizerFolder: URL? = nil,
computeOptions: ModelComputeOptions? = nil,
Expand Down Expand Up @@ -294,14 +294,16 @@ open class WhisperKit {

var variantPath: String? = nil

if uniquePaths.count == 1 {
if uniquePaths.count == 0 {
throw WhisperError.modelsUnavailable("Could not find model matching \"\(modelSearchPath)\"")
} else if uniquePaths.count == 1 {
variantPath = uniquePaths.first
} else if specificOnly {
// We only want the one specific model, and won't accept fuzzy fallbacks
throw WhisperError.modelsUnavailable("Multiple models found matching \"\(modelSearchPath)\" and specificOnly was set")
} else {
// If the model name search returns more than one unique model folder, then prepend the default "openai" prefix from whisperkittools to disambiguate
Logging.debug("Multiple models found matching \"\(modelSearchPath)\"")
Logging.debug("No definitive model matching \"\(modelSearchPath)\"")
let adjustedModelSearchPath = "*openai*\(variant.description)"
Logging.debug("Searching for models matching \"\(adjustedModelSearchPath)\" in \(repo)")
let adjustedModelFiles = dirFiles.matching(glob: adjustedModelSearchPath)
Expand All @@ -314,7 +316,7 @@ open class WhisperKit {

guard let variantPath else {
// If there is still ambiguity, throw an error
throw WhisperError.modelsUnavailable("Multiple models found matching \"\(modelSearchPath)\"")
throw WhisperError.modelsUnavailable("Could not find definitive model matching \"\(modelSearchPath)\"")
}

let truePath = repoDestination.appending(path: variantPath)
Expand Down Expand Up @@ -344,7 +346,7 @@ open class WhisperKit {
open func setupModels(
model: String?,
downloadBase: URL? = nil,
modelRepo: String = "argmaxinc/whisperkit-coreml",
modelRepo: String? = nil,
modelFolder: String?,
download: Bool
) async throws {
Expand All @@ -356,12 +358,13 @@ open class WhisperKit {
let modelSupport = await WhisperKit.recommendedRemoteModels()
let modelVariant = model ?? modelSupport.default

let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
do {
self.modelFolder = try await Self.download(
variant: modelVariant,
downloadBase: downloadBase,
useBackgroundSession: useBackgroundDownloadSession,
from: modelRepo
from: repo
)
} catch {
// Handle errors related to model downloading
Expand All @@ -373,7 +376,8 @@ open class WhisperKit {
} else {
let modelSupport = WhisperKit.recommendedModels()
let modelVariant = model ?? modelSupport.default
let folder = try Self.modelLocation(variant: modelVariant, downloadBase: downloadBase, from: modelRepo)
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
guard let folder = try? Self.modelLocation(variant: modelVariant, downloadBase: downloadBase, from: repo) else { return }
self.modelFolder = folder
}
}
Expand Down
40 changes: 40 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,46 @@ final class UnitTests: XCTestCase {
"Failed to init WhisperKit"
)
}
func testModelIsInstalled() async throws {
XCTAssertTrue(
WhisperKit.modelInstalled(variant: "openai_whisper-tiny"),
"Model was not installed"
)
XCTAssertFalse(
WhisperKit.modelInstalled(variant: "THIS_MODEL_DOES_NOT_EXIST"),
"Model does not exist, but returned true"
)
XCTAssertTrue(
WhisperKit.modelInstalled(variant: "tiny"),
"Model was not installed"
)
XCTAssertTrue(
WhisperKit.modelInstalled(variant: "tiny.en"),
"Model was not installed"
)
let location = try WhisperKit.modelLocation(variant: "tiny")
let location2 = try WhisperKit.modelLocation(variant: "openai_whisper-tiny")
XCTAssertEqual(location, location2, "Auto fallback to OpenAI model returned a different result")
}
func testOfflineMode() async throws {

let pipe = try await WhisperKit(WhisperKitConfig(model: "tiny.en", download: false))

let cancellable: AnyCancellable? = pipe.progress.publisher(for: \.fractionCompleted)
.removeDuplicates()
.withPrevious()
.sink { previous, current in
if let previous {
XCTAssertLessThan(previous, current)
}
}
_ = try await pipe.transcribe(
audioPath: Bundle.current.path(forResource: "ted_60", ofType: "m4a")!,
decodeOptions: .init(chunkingStrategy: .vad)
)
cancellable?.cancel()

}

// MARK: - Config Tests

Expand Down

0 comments on commit c922681

Please sign in to comment.