diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index d3d6c5bc..d0b513fe 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -769,9 +769,12 @@ open class WhisperKit {
         }
 
         try Task.checkCancellation()
+        let childProgress = Progress()
+        progress.totalUnitCount += 1
+        progress.addChild(childProgress, withPendingUnitCount: 1)
         let transcribeTask = TranscribeTask(
             currentTimings: currentTimings,
-            progress: progress,
+            progress: childProgress,
             audioEncoder: audioEncoder,
             featureExtractor: featureExtractor,
             segmentSeeker: segmentSeeker,
diff --git a/Tests/WhisperKitTests/TestUtils.swift b/Tests/WhisperKitTests/TestUtils.swift
index a173fb8b..d69ec56b 100644
--- a/Tests/WhisperKitTests/TestUtils.swift
+++ b/Tests/WhisperKitTests/TestUtils.swift
@@ -1,4 +1,5 @@
 import CoreML
+import Combine
 import Foundation
 @testable import WhisperKit
 import XCTest
@@ -274,3 +275,15 @@ extension Collection where Element == TranscriptionResult {
         flatMap(\.segments)
     }
 }
+
+extension Publisher {
+    /// Pairs every value from the upstream publisher with the value emitted just before it.
+    ///
+    /// The first pair has a `nil` `previous`; each later pair carries the preceding output.
+    public func withPrevious() -> AnyPublisher<(previous: Output?, current: Output), Failure> {
+        let seed: (Output?, Output)? = nil
+        return scan(seed) { lastPair, value in (lastPair?.1, value) }
+            .compactMap { $0 }
+            .eraseToAnyPublisher()
+    }
+}
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index f14029a1..bcd09d3a 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -1,6 +1,7 @@
 // For licensing see accompanying LICENSE.md file.
 // Copyright © 2024 Argmax, Inc. All rights reserved.
 
+import Combine
 import AVFoundation
 import CoreML
 import Hub
@@ -1178,6 +1179,32 @@ final class UnitTests: XCTestCase {
         XCTAssertTrue(chunkedResult.text.normalized.contains("But then came my 90 page senior".normalized), "Expected text not found in \(chunkedResult.text.normalized)")
     }
 
+    /// Checks that `fractionCompleted` on the pipeline's progress strictly increases
+    /// while transcribing with the `.vad` chunking strategy.
+    func testVADProgress() async throws {
+        let pipe = try await WhisperKit(model: "tiny.en")
+        let audioPath = try XCTUnwrap(
+            Bundle.module.path(forResource: "ted_60", ofType: "m4a"),
+            "Audio file not found"
+        )
+
+        let cancellable = pipe.progress.publisher(for: \.fractionCompleted)
+            .removeDuplicates()
+            .withPrevious()
+            .sink { previous, current in
+                if let previous {
+                    XCTAssertLessThan(previous, current, "Progress must never move backwards")
+                }
+            }
+        // Stop observing on every exit path, including when transcription throws.
+        defer { cancellable.cancel() }
+
+        _ = try await pipe.transcribe(
+            audioPath: audioPath,
+            decodeOptions: .init(chunkingStrategy: .vad)
+        )
+    }
+
     // MARK: - Word Timestamp Tests
 
     func testDynamicTimeWarpingSimpleMatrix() {