diff --git a/Package.resolved b/Package.resolved index d2a1e341..ca63aa18 100644 --- a/Package.resolved +++ b/Package.resolved @@ -14,8 +14,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "24605a8c0cc974bec5b94a6752eb687bae77db31", - "version" : "0.1.3" + "revision" : "3bd02269b7797ade67c15679a575cd5c6f203ce6", + "version" : "0.1.5" } } ], diff --git a/Package.swift b/Package.swift index a7a59961..b0dce2f1 100644 --- a/Package.swift +++ b/Package.swift @@ -20,7 +20,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.3"), + .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.5"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), ], targets: [ diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index a2cf11bd..9f46247d 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -272,10 +272,11 @@ public func resolveAbsolutePath(_ inputPath: String) -> String { func loadTokenizer( for pretrained: ModelVariant, - tokenizerFolder: URL? = nil + tokenizerFolder: URL? = nil, + useBackgroundSession: Bool = false ) async throws -> Tokenizer { let tokenizerName = tokenizerNameForVariant(pretrained) - let hubApi = HubApi(downloadBase: tokenizerFolder) + let hubApi = HubApi(downloadBase: tokenizerFolder, useBackgroundSession: useBackgroundSession) return try await AutoTokenizer.from(pretrained: tokenizerName, hubApi: hubApi) } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 97b301d4..4b5c9a05 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -21,8 +21,6 @@ public class WhisperKit: Transcriber { public var modelVariant: ModelVariant = .tiny public var modelState: ModelState = .unloaded public var modelCompute: ModelComputeOptions - public var modelFolder: URL? - public var tokenizerFolder: URL? public var tokenizer: Tokenizer? /// Protocols @@ -48,7 +46,13 @@ public class WhisperKit: Transcriber { public var decoderInputs: DecodingInputs? public var currentTimings: TranscriptionTimings? + /// State public let progress = Progress() + + /// Configuration + public var modelFolder: URL? + public var tokenizerFolder: URL? + private let useBackgroundDownloadSession: Bool public init( model: String? = nil, @@ -67,7 +71,8 @@ public class WhisperKit: Transcriber { logLevel: Logging.LogLevel = .info, prewarm: Bool? = nil, load: Bool? = nil, - download: Bool = true + download: Bool = true, + useBackgroundDownloadSession: Bool = false ) async throws { self.modelCompute = computeOptions ?? ModelComputeOptions() self.audioProcessor = audioProcessor ?? AudioProcessor() @@ -77,6 +82,7 @@ public class WhisperKit: Transcriber { self.logitsFilters = logitsFilters ?? [] self.segmentSeeker = segmentSeeker ?? SegmentSeeker() self.tokenizerFolder = tokenizerFolder + self.useBackgroundDownloadSession = useBackgroundDownloadSession Logging.shared.logLevel = verbose ? logLevel : .none currentTimings = TranscriptionTimings() @@ -170,8 +176,14 @@ public class WhisperKit: Transcriber { return sortedModels } - public static func download(variant: String, downloadBase: URL? = nil, from repo: String = "argmaxinc/whisperkit-coreml", progressCallback: ((Progress) -> Void)? = nil) async throws -> URL? { - let hubApi = HubApi(downloadBase: downloadBase) + public static func download( + variant: String, + downloadBase: URL? = nil, + useBackgroundSession: Bool = false, + from repo: String = "argmaxinc/whisperkit-coreml", + progressCallback: ((Progress) -> Void)? = nil + ) async throws -> URL? { + let hubApi = HubApi(downloadBase: downloadBase, useBackgroundSession: useBackgroundSession) let repo = Hub.Repo(id: repo, type: .models) do { let modelFolder = try await hubApi.snapshot(from: repo, matching: ["*\(variant.description)/*"]) { progress in @@ -191,7 +203,13 @@ public class WhisperKit: Transcriber { } /// Sets up the model folder either from a local path or by downloading from a repository. - public func setupModels(model: String?, downloadBase: URL? = nil, modelRepo: String?, modelFolder: String?, download: Bool) async throws { + public func setupModels( + model: String?, + downloadBase: URL? = nil, + modelRepo: String?, + modelFolder: String?, + download: Bool + ) async throws { // Determine the model variant to use let modelVariant = model ?? WhisperKit.recommendedModels().default @@ -201,7 +219,12 @@ public class WhisperKit: Transcriber { } else if download { let repo = modelRepo ?? "argmaxinc/whisperkit-coreml" do { - let hubModelFolder = try await Self.download(variant: modelVariant, downloadBase: downloadBase, from: repo) + let hubModelFolder = try await Self.download( + variant: modelVariant, + downloadBase: downloadBase, + useBackgroundSession: useBackgroundDownloadSession, + from: repo + ) self.modelFolder = hubModelFolder! } catch { // Handle errors related to model downloading @@ -217,7 +240,9 @@ public class WhisperKit: Transcriber { try await loadModels(prewarmMode: true) } - public func loadModels(prewarmMode: Bool = false) async throws { + public func loadModels( + prewarmMode: Bool = false + ) async throws { modelState = prewarmMode ? .prewarming : .loading let modelLoadStart = CFAbsoluteTimeGetCurrent() @@ -292,7 +317,11 @@ public class WhisperKit: Transcriber { { modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim) Logging.debug("Loading tokenizer for \(modelVariant)") - tokenizer = try await loadTokenizer(for: modelVariant, tokenizerFolder: tokenizerFolder) + tokenizer = try await loadTokenizer( + for: modelVariant, + tokenizerFolder: tokenizerFolder, + useBackgroundSession: useBackgroundDownloadSession + ) textDecoder.tokenizer = tokenizer Logging.debug("Loaded tokenizer") } else { diff --git a/Sources/WhisperKitCLI/Transcribe.swift b/Sources/WhisperKitCLI/Transcribe.swift index 684afc6f..fd60e116 100644 --- a/Sources/WhisperKitCLI/Transcribe.swift +++ b/Sources/WhisperKitCLI/Transcribe.swift @@ -12,7 +12,7 @@ struct Transcribe: AsyncParsableCommand { abstract: "Transcribe audio to text using WhisperKit" ) - @OptionGroup + @OptionGroup var cliArguments: CLIArguments mutating func run() async throws { @@ -36,14 +36,14 @@ struct Transcribe: AsyncParsableCommand { audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits, textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits ) - + let downloadTokenizerFolder: URL? = if let filePath = cliArguments.downloadTokenizerPath { URL(filePath: filePath) } else { nil } - + let downloadModelFolder: URL? = if let filePath = cliArguments.downloadModelPath { URL(filePath: filePath) @@ -51,7 +51,10 @@ struct Transcribe: AsyncParsableCommand { nil } - print("Initializing models...") + if cliArguments.verbose { + print("Initializing models...") + } + let whisperKit = try await WhisperKit( model: cliArguments.model, downloadBase: downloadModelFolder, @@ -59,9 +62,13 @@ struct Transcribe: AsyncParsableCommand { tokenizerFolder: downloadTokenizerFolder, computeOptions: computeOptions, verbose: cliArguments.verbose, - logLevel: .debug + logLevel: .debug, + useBackgroundDownloadSession: false ) - print("Models initialized") + + if cliArguments.verbose { + print("Models initialized") + } let options = DecodingOptions( verbose: cliArguments.verbose, @@ -83,7 +90,7 @@ struct Transcribe: AsyncParsableCommand { ) let transcribeResult = try await whisperKit.transcribe( - audioPath: resolvedAudioPath, + audioPath: resolvedAudioPath, decodeOptions: options ) @@ -136,7 +143,7 @@ struct Transcribe: AsyncParsableCommand { } else { nil } - + let downloadModelFolder: URL? = if let filePath = cliArguments.downloadModelPath { URL(filePath: filePath) @@ -144,7 +151,10 @@ struct Transcribe: AsyncParsableCommand { nil } - print("Initializing models...") + if cliArguments.verbose { + print("Initializing models...") + } + let whisperKit = try await WhisperKit( model: cliArguments.model, downloadBase: downloadModelFolder, @@ -152,10 +162,13 @@ struct Transcribe: AsyncParsableCommand { tokenizerFolder: downloadTokenizerFolder, computeOptions: computeOptions, verbose: cliArguments.verbose, - logLevel: .debug + logLevel: .debug, + useBackgroundDownloadSession: false ) - print("Models initialized") + if cliArguments.verbose { + print("Models initialized") + } let decodingOptions = DecodingOptions( verbose: cliArguments.verbose, task: .transcribe,