Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated swift-transformers, do not use background url session in CLI #74

Merged
merged 6 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "24605a8c0cc974bec5b94a6752eb687bae77db31",
"version" : "0.1.3"
"revision" : "3bd02269b7797ade67c15679a575cd5c6f203ce6",
"version" : "0.1.5"
}
}
],
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.3"),
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.5"),
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
],
targets: [
Expand Down
5 changes: 3 additions & 2 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,11 @@ public func resolveAbsolutePath(_ inputPath: String) -> String {

func loadTokenizer(
for pretrained: ModelVariant,
tokenizerFolder: URL? = nil
tokenizerFolder: URL? = nil,
useBackgroundSession: Bool = false
) async throws -> Tokenizer {
let tokenizerName = tokenizerNameForVariant(pretrained)
let hubApi = HubApi(downloadBase: tokenizerFolder)
let hubApi = HubApi(downloadBase: tokenizerFolder, useBackgroundSession: useBackgroundSession)
return try await AutoTokenizer.from(pretrained: tokenizerName, hubApi: hubApi)
}

Expand Down
47 changes: 38 additions & 9 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ public class WhisperKit: Transcriber {
public var modelVariant: ModelVariant = .tiny
public var modelState: ModelState = .unloaded
public var modelCompute: ModelComputeOptions
public var modelFolder: URL?
public var tokenizerFolder: URL?
public var tokenizer: Tokenizer?

/// Protocols
Expand All @@ -48,7 +46,13 @@ public class WhisperKit: Transcriber {
public var decoderInputs: DecodingInputs?
public var currentTimings: TranscriptionTimings?

/// State
public let progress = Progress()

/// Configuration
public var modelFolder: URL?
public var tokenizerFolder: URL?
private let useBackgroundDownloadSession: Bool

public init(
model: String? = nil,
Expand All @@ -67,7 +71,8 @@ public class WhisperKit: Transcriber {
logLevel: Logging.LogLevel = .info,
prewarm: Bool? = nil,
load: Bool? = nil,
download: Bool = true
download: Bool = true,
useBackgroundDownloadSession: Bool = false
) async throws {
self.modelCompute = computeOptions ?? ModelComputeOptions()
self.audioProcessor = audioProcessor ?? AudioProcessor()
Expand All @@ -77,6 +82,7 @@ public class WhisperKit: Transcriber {
self.logitsFilters = logitsFilters ?? []
self.segmentSeeker = segmentSeeker ?? SegmentSeeker()
self.tokenizerFolder = tokenizerFolder
self.useBackgroundDownloadSession = useBackgroundDownloadSession
Logging.shared.logLevel = verbose ? logLevel : .none
currentTimings = TranscriptionTimings()

Expand Down Expand Up @@ -170,8 +176,14 @@ public class WhisperKit: Transcriber {
return sortedModels
}

public static func download(variant: String, downloadBase: URL? = nil, from repo: String = "argmaxinc/whisperkit-coreml", progressCallback: ((Progress) -> Void)? = nil) async throws -> URL? {
let hubApi = HubApi(downloadBase: downloadBase)
public static func download(
variant: String,
downloadBase: URL? = nil,
useBackgroundSession: Bool = false,
from repo: String = "argmaxinc/whisperkit-coreml",
progressCallback: ((Progress) -> Void)? = nil
) async throws -> URL? {
let hubApi = HubApi(downloadBase: downloadBase, useBackgroundSession: useBackgroundSession)
let repo = Hub.Repo(id: repo, type: .models)
do {
let modelFolder = try await hubApi.snapshot(from: repo, matching: ["*\(variant.description)/*"]) { progress in
Expand All @@ -191,7 +203,13 @@ public class WhisperKit: Transcriber {
}

/// Sets up the model folder either from a local path or by downloading from a repository.
public func setupModels(model: String?, downloadBase: URL? = nil, modelRepo: String?, modelFolder: String?, download: Bool) async throws {
public func setupModels(
model: String?,
downloadBase: URL? = nil,
modelRepo: String?,
modelFolder: String?,
download: Bool
) async throws {
// Determine the model variant to use
let modelVariant = model ?? WhisperKit.recommendedModels().default

Expand All @@ -201,7 +219,12 @@ public class WhisperKit: Transcriber {
} else if download {
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
do {
let hubModelFolder = try await Self.download(variant: modelVariant, downloadBase: downloadBase, from: repo)
let hubModelFolder = try await Self.download(
variant: modelVariant,
downloadBase: downloadBase,
useBackgroundSession: useBackgroundDownloadSession,
from: repo
)
self.modelFolder = hubModelFolder!
} catch {
// Handle errors related to model downloading
Expand All @@ -217,7 +240,9 @@ public class WhisperKit: Transcriber {
try await loadModels(prewarmMode: true)
}

public func loadModels(prewarmMode: Bool = false) async throws {
public func loadModels(
prewarmMode: Bool = false
) async throws {
modelState = prewarmMode ? .prewarming : .loading

let modelLoadStart = CFAbsoluteTimeGetCurrent()
Expand Down Expand Up @@ -292,7 +317,11 @@ public class WhisperKit: Transcriber {
{
modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim)
Logging.debug("Loading tokenizer for \(modelVariant)")
tokenizer = try await loadTokenizer(for: modelVariant, tokenizerFolder: tokenizerFolder)
tokenizer = try await loadTokenizer(
for: modelVariant,
tokenizerFolder: tokenizerFolder,
useBackgroundSession: useBackgroundDownloadSession
)
textDecoder.tokenizer = tokenizer
Logging.debug("Loaded tokenizer")
} else {
Expand Down
35 changes: 24 additions & 11 deletions Sources/WhisperKitCLI/Transcribe.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct Transcribe: AsyncParsableCommand {
abstract: "Transcribe audio to text using WhisperKit"
)

@OptionGroup
@OptionGroup
var cliArguments: CLIArguments

mutating func run() async throws {
Expand All @@ -36,32 +36,39 @@ struct Transcribe: AsyncParsableCommand {
audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
)

let downloadTokenizerFolder: URL? =
if let filePath = cliArguments.downloadTokenizerPath {
URL(filePath: filePath)
} else {
nil
}

let downloadModelFolder: URL? =
if let filePath = cliArguments.downloadModelPath {
URL(filePath: filePath)
} else {
nil
}

print("Initializing models...")
if cliArguments.verbose {
print("Initializing models...")
}

let whisperKit = try await WhisperKit(
model: cliArguments.model,
downloadBase: downloadModelFolder,
modelFolder: cliArguments.modelPath,
tokenizerFolder: downloadTokenizerFolder,
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug
logLevel: .debug,
useBackgroundDownloadSession: false
)
print("Models initialized")

if cliArguments.verbose {
print("Models initialized")
}

let options = DecodingOptions(
verbose: cliArguments.verbose,
Expand All @@ -83,7 +90,7 @@ struct Transcribe: AsyncParsableCommand {
)

let transcribeResult = try await whisperKit.transcribe(
audioPath: resolvedAudioPath,
audioPath: resolvedAudioPath,
decodeOptions: options
)

Expand Down Expand Up @@ -136,26 +143,32 @@ struct Transcribe: AsyncParsableCommand {
} else {
nil
}

let downloadModelFolder: URL? =
if let filePath = cliArguments.downloadModelPath {
URL(filePath: filePath)
} else {
nil
}

print("Initializing models...")
if cliArguments.verbose {
print("Initializing models...")
}

let whisperKit = try await WhisperKit(
model: cliArguments.model,
downloadBase: downloadModelFolder,
modelFolder: cliArguments.modelPath,
tokenizerFolder: downloadTokenizerFolder,
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug
logLevel: .debug,
useBackgroundDownloadSession: false
)
print("Models initialized")

if cliArguments.verbose {
print("Models initialized")
}
let decodingOptions = DecodingOptions(
verbose: cliArguments.verbose,
task: .transcribe,
Expand Down