From c37caf71755d7d01ee74cd0ae758fc702d89df48 Mon Sep 17 00:00:00 2001 From: Jan Krukowski Date: Fri, 15 Mar 2024 17:18:50 +0100 Subject: [PATCH 1/6] updated swift-transformers, do not use background session in CLI --- Package.resolved | 4 +- Package.swift | 2 +- Sources/WhisperKit/Core/Utils.swift | 5 +- Sources/WhisperKit/Core/WhisperKit.swift | 58 ++++++++++++++++++------ Sources/WhisperKitCLI/Transcribe.swift | 6 ++- 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/Package.resolved b/Package.resolved index d2a1e341..e1f8c9ad 100644 --- a/Package.resolved +++ b/Package.resolved @@ -14,8 +14,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "24605a8c0cc974bec5b94a6752eb687bae77db31", - "version" : "0.1.3" + "revision" : "5105f7238bbb71e66a600408714d52db1c254a7d", + "version" : "0.1.4" } } ], diff --git a/Package.swift b/Package.swift index a7a59961..e416081a 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.3"), + .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.4"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), ], targets: [ diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index a2cf11bd..92dc88d3 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -272,10 +272,11 @@ public func resolveAbsolutePath(_ inputPath: String) -> String { func loadTokenizer( for pretrained: ModelVariant, - tokenizerFolder: URL? = nil + tokenizerFolder: URL? = nil, + useBackgroundSession: Bool = true ) async throws -> Tokenizer { let tokenizerName = tokenizerNameForVariant(pretrained) - let hubApi = HubApi(downloadBase: tokenizerFolder) + let hubApi = HubApi(downloadBase: tokenizerFolder, useBackgroundSession: useBackgroundSession) return try await AutoTokenizer.from(pretrained: tokenizerName, hubApi: hubApi) } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 97b301d4..ae9207d8 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -49,6 +49,9 @@ public class WhisperKit: Transcriber { public var currentTimings: TranscriptionTimings? public let progress = Progress() + + /// Configuration + private let useBackgroundDownloadSession: Bool public init( model: String? = nil, @@ -67,7 +70,9 @@ public class WhisperKit: Transcriber { logLevel: Logging.LogLevel = .info, prewarm: Bool? = nil, load: Bool? = nil, - download: Bool = true + download: Bool = true, + useBackgroundDownloadSession: Bool = true + ) async throws { self.modelCompute = computeOptions ?? ModelComputeOptions() self.audioProcessor = audioProcessor ?? AudioProcessor() @@ -77,6 +82,7 @@ public class WhisperKit: Transcriber { self.logitsFilters = logitsFilters ?? [] self.segmentSeeker = segmentSeeker ?? SegmentSeeker() self.tokenizerFolder = tokenizerFolder + self.useBackgroundDownloadSession = useBackgroundDownloadSession Logging.shared.logLevel = verbose ? logLevel : .none currentTimings = TranscriptionTimings() @@ -85,18 +91,19 @@ public class WhisperKit: Transcriber { downloadBase: downloadBase, modelRepo: modelRepo, modelFolder: modelFolder, - download: download + download: download, + useBackgroundDownloadSession: useBackgroundDownloadSession ) if let prewarm = prewarm, prewarm { Logging.info("Prewarming models...") - try await prewarmModels() + try await prewarmModels(useBackgroundDownloadSession: useBackgroundDownloadSession) } // If load is not passed in, load based on whether a modelFolder is passed if load ?? (modelFolder != nil) { Logging.info("Loading models...") - try await loadModels() + try await loadModels(useBackgroundDownloadSession: useBackgroundDownloadSession) } } @@ -170,8 +177,14 @@ public class WhisperKit: Transcriber { return sortedModels } - public static func download(variant: String, downloadBase: URL? = nil, from repo: String = "argmaxinc/whisperkit-coreml", progressCallback: ((Progress) -> Void)? = nil) async throws -> URL? { - let hubApi = HubApi(downloadBase: downloadBase) + public static func download( + variant: String, + downloadBase: URL? = nil, + useBackgroundSession: Bool = true, + from repo: String = "argmaxinc/whisperkit-coreml", + progressCallback: ((Progress) -> Void)? = nil + ) async throws -> URL? { + let hubApi = HubApi(downloadBase: downloadBase, useBackgroundSession: useBackgroundSession) let repo = Hub.Repo(id: repo, type: .models) do { let modelFolder = try await hubApi.snapshot(from: repo, matching: ["*\(variant.description)/*"]) { progress in @@ -191,7 +204,14 @@ public class WhisperKit: Transcriber { } /// Sets up the model folder either from a local path or by downloading from a repository. - public func setupModels(model: String?, downloadBase: URL? = nil, modelRepo: String?, modelFolder: String?, download: Bool) async throws { + public func setupModels( + model: String?, + downloadBase: URL? = nil, + modelRepo: String?, + modelFolder: String?, + download: Bool, + useBackgroundDownloadSession: Bool + ) async throws { // Determine the model variant to use let modelVariant = model ?? WhisperKit.recommendedModels().default @@ -201,7 +221,12 @@ public class WhisperKit: Transcriber { } else if download { let repo = modelRepo ?? "argmaxinc/whisperkit-coreml" do { - let hubModelFolder = try await Self.download(variant: modelVariant, downloadBase: downloadBase, from: repo) + let hubModelFolder = try await Self.download( + variant: modelVariant, + downloadBase: downloadBase, + useBackgroundSession: useBackgroundDownloadSession, + from: repo + ) self.modelFolder = hubModelFolder! } catch { // Handle errors related to model downloading @@ -213,11 +238,14 @@ public class WhisperKit: Transcriber { } } - public func prewarmModels() async throws { - try await loadModels(prewarmMode: true) + public func prewarmModels(useBackgroundDownloadSession: Bool) async throws { + try await loadModels(prewarmMode: true, useBackgroundDownloadSession: useBackgroundDownloadSession) } - public func loadModels(prewarmMode: Bool = false) async throws { + public func loadModels( + prewarmMode: Bool = false, + useBackgroundDownloadSession: Bool + ) async throws { modelState = prewarmMode ? .prewarming : .loading let modelLoadStart = CFAbsoluteTimeGetCurrent() @@ -292,7 +320,11 @@ public class WhisperKit: Transcriber { { modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim) Logging.debug("Loading tokenizer for \(modelVariant)") - tokenizer = try await loadTokenizer(for: modelVariant, tokenizerFolder: tokenizerFolder) + tokenizer = try await loadTokenizer( + for: modelVariant, + tokenizerFolder: tokenizerFolder, + useBackgroundSession: useBackgroundDownloadSession + ) textDecoder.tokenizer = tokenizer Logging.debug("Loaded tokenizer") } else { @@ -379,7 +411,7 @@ public class WhisperKit: Transcriber { } if self.modelState != .loaded { - try await loadModels() + try await loadModels(useBackgroundDownloadSession: useBackgroundDownloadSession) } var timings = currentTimings! diff --git a/Sources/WhisperKitCLI/Transcribe.swift b/Sources/WhisperKitCLI/Transcribe.swift index 684afc6f..959e05ef 100644 --- a/Sources/WhisperKitCLI/Transcribe.swift +++ b/Sources/WhisperKitCLI/Transcribe.swift @@ -59,7 +59,8 @@ struct Transcribe: AsyncParsableCommand { tokenizerFolder: downloadTokenizerFolder, computeOptions: computeOptions, verbose: cliArguments.verbose, - logLevel: .debug + logLevel: .debug, + useBackgroundDownloadSession: false ) print("Models initialized") @@ -152,7 +153,8 @@ struct Transcribe: AsyncParsableCommand { tokenizerFolder: downloadTokenizerFolder, computeOptions: computeOptions, verbose: cliArguments.verbose, - logLevel: .debug + logLevel: .debug, + useBackgroundDownloadSession: false ) print("Models initialized") From fc07917935153b72e303738fc870e3f29bf0117a Mon Sep 17 00:00:00 2001 From: Jan Krukowski Date: Fri, 15 Mar 2024 17:19:53 +0100 Subject: [PATCH 2/6] removed empty line --- Sources/WhisperKit/Core/WhisperKit.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index ae9207d8..17fe4612 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -72,7 +72,6 @@ public class WhisperKit: Transcriber { load: Bool? = nil, download: Bool = true, useBackgroundDownloadSession: Bool = true - ) async throws { self.modelCompute = computeOptions ?? ModelComputeOptions() self.audioProcessor = audioProcessor ?? AudioProcessor() From 07c995341af9d6cf411cc0745d922e02eadae26d Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Fri, 15 Mar 2024 09:28:26 -0700 Subject: [PATCH 3/6] Organize properties --- Sources/WhisperKit/Core/WhisperKit.swift | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 17fe4612..b3d4eb2c 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -21,8 +21,6 @@ public class WhisperKit: Transcriber { public var modelVariant: ModelVariant = .tiny public var modelState: ModelState = .unloaded public var modelCompute: ModelComputeOptions - public var modelFolder: URL? - public var tokenizerFolder: URL? public var tokenizer: Tokenizer? /// Protocols @@ -48,9 +46,12 @@ public class WhisperKit: Transcriber { public var decoderInputs: DecodingInputs? public var currentTimings: TranscriptionTimings? + /// State public let progress = Progress() /// Configuration + public var modelFolder: URL? + public var tokenizerFolder: URL? private let useBackgroundDownloadSession: Bool public init( From eac0e8e656d32840ba3a4cf8bc594b19851ee334 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Fri, 15 Mar 2024 09:37:56 -0700 Subject: [PATCH 4/6] Cleanup stdout in non-verbose mode --- Sources/WhisperKitCLI/Transcribe.swift | 29 ++++++++++++++++++-------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/Sources/WhisperKitCLI/Transcribe.swift b/Sources/WhisperKitCLI/Transcribe.swift index 959e05ef..fd60e116 100644 --- a/Sources/WhisperKitCLI/Transcribe.swift +++ b/Sources/WhisperKitCLI/Transcribe.swift @@ -12,7 +12,7 @@ struct Transcribe: AsyncParsableCommand { abstract: "Transcribe audio to text using WhisperKit" ) - @OptionGroup + @OptionGroup var cliArguments: CLIArguments mutating func run() async throws { @@ -36,14 +36,14 @@ struct Transcribe: AsyncParsableCommand { audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits, textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits ) - + let downloadTokenizerFolder: URL? = if let filePath = cliArguments.downloadTokenizerPath { URL(filePath: filePath) } else { nil } - + let downloadModelFolder: URL? = if let filePath = cliArguments.downloadModelPath { URL(filePath: filePath) @@ -51,7 +51,10 @@ struct Transcribe: AsyncParsableCommand { nil } - print("Initializing models...") + if cliArguments.verbose { + print("Initializing models...") + } + let whisperKit = try await WhisperKit( model: cliArguments.model, downloadBase: downloadModelFolder, @@ -62,7 +65,10 @@ struct Transcribe: AsyncParsableCommand { logLevel: .debug, useBackgroundDownloadSession: false ) - print("Models initialized") + + if cliArguments.verbose { + print("Models initialized") + } let options = DecodingOptions( verbose: cliArguments.verbose, @@ -84,7 +90,7 @@ struct Transcribe: AsyncParsableCommand { ) let transcribeResult = try await whisperKit.transcribe( - audioPath: resolvedAudioPath, + audioPath: resolvedAudioPath, decodeOptions: options ) @@ -137,7 +143,7 @@ struct Transcribe: AsyncParsableCommand { } else { nil } - + let downloadModelFolder: URL? = if let filePath = cliArguments.downloadModelPath { URL(filePath: filePath) @@ -145,7 +151,10 @@ struct Transcribe: AsyncParsableCommand { nil } - print("Initializing models...") + if cliArguments.verbose { + print("Initializing models...") + } + let whisperKit = try await WhisperKit( model: cliArguments.model, downloadBase: downloadModelFolder, @@ -156,8 +165,10 @@ struct Transcribe: AsyncParsableCommand { logLevel: .debug, useBackgroundDownloadSession: false ) - print("Models initialized") + if cliArguments.verbose { + print("Models initialized") + } let decodingOptions = DecodingOptions( verbose: cliArguments.verbose, task: .transcribe, From eef4ba89d949fa57ddc841255541c4bcefe498d8 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Fri, 15 Mar 2024 15:44:07 -0700 Subject: [PATCH 5/6] Update swift-transformers version --- Package.resolved | 4 ++-- Package.swift | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Package.resolved b/Package.resolved index e1f8c9ad..ca63aa18 100644 --- a/Package.resolved +++ b/Package.resolved @@ -14,8 +14,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "5105f7238bbb71e66a600408714d52db1c254a7d", - "version" : "0.1.4" + "revision" : "3bd02269b7797ade67c15679a575cd5c6f203ce6", + "version" : "0.1.5" } } ], diff --git a/Package.swift b/Package.swift index e416081a..b0dce2f1 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.4"), + .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.5"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), ], targets: [ From 7d809a2efdcd243b04061e60d089b7d5543c25b1 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Fri, 15 Mar 2024 15:56:55 -0700 Subject: [PATCH 6/6] Default to non-background download session --- Sources/WhisperKit/Core/Utils.swift | 2 +- Sources/WhisperKit/Core/WhisperKit.swift | 23 ++++++++++------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index 92dc88d3..9f46247d 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -273,7 +273,7 @@ public func resolveAbsolutePath(_ inputPath: String) -> String { func loadTokenizer( for pretrained: ModelVariant, tokenizerFolder: URL? = nil, - useBackgroundSession: Bool = true + useBackgroundSession: Bool = false ) async throws -> Tokenizer { let tokenizerName = tokenizerNameForVariant(pretrained) let hubApi = HubApi(downloadBase: tokenizerFolder, useBackgroundSession: useBackgroundSession) diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index b3d4eb2c..4b5c9a05 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -72,7 +72,7 @@ public class WhisperKit: Transcriber { prewarm: Bool? = nil, load: Bool? = nil, download: Bool = true, - useBackgroundDownloadSession: Bool = true + useBackgroundDownloadSession: Bool = false ) async throws { self.modelCompute = computeOptions ?? ModelComputeOptions() self.audioProcessor = audioProcessor ?? AudioProcessor() @@ -91,19 +91,18 @@ public class WhisperKit: Transcriber { downloadBase: downloadBase, modelRepo: modelRepo, modelFolder: modelFolder, - download: download, - useBackgroundDownloadSession: useBackgroundDownloadSession + download: download ) if let prewarm = prewarm, prewarm { Logging.info("Prewarming models...") - try await prewarmModels(useBackgroundDownloadSession: useBackgroundDownloadSession) + try await prewarmModels() } // If load is not passed in, load based on whether a modelFolder is passed if load ?? (modelFolder != nil) { Logging.info("Loading models...") - try await loadModels(useBackgroundDownloadSession: useBackgroundDownloadSession) + try await loadModels() } } @@ -180,7 +179,7 @@ public class WhisperKit: Transcriber { public static func download( variant: String, downloadBase: URL? = nil, - useBackgroundSession: Bool = true, + useBackgroundSession: Bool = false, from repo: String = "argmaxinc/whisperkit-coreml", progressCallback: ((Progress) -> Void)? = nil ) async throws -> URL? { @@ -209,8 +208,7 @@ public class WhisperKit: Transcriber { downloadBase: URL? = nil, modelRepo: String?, modelFolder: String?, - download: Bool, - useBackgroundDownloadSession: Bool + download: Bool ) async throws { // Determine the model variant to use let modelVariant = model ?? WhisperKit.recommendedModels().default @@ -238,13 +236,12 @@ public class WhisperKit: Transcriber { } } - public func prewarmModels(useBackgroundDownloadSession: Bool) async throws { - try await loadModels(prewarmMode: true, useBackgroundDownloadSession: useBackgroundDownloadSession) + public func prewarmModels() async throws { + try await loadModels(prewarmMode: true) } public func loadModels( - prewarmMode: Bool = false, - useBackgroundDownloadSession: Bool + prewarmMode: Bool = false ) async throws { modelState = prewarmMode ? .prewarming : .loading @@ -411,7 +408,7 @@ public class WhisperKit: Transcriber { } if self.modelState != .loaded { - try await loadModels(useBackgroundDownloadSession: useBackgroundDownloadSession) + try await loadModels() } var timings = currentTimings!