Skip to content

Commit e939dc3

Browse files
authored
whisper : add Core ML support (ggml-org#566)
* coreml : use Core ML encoder inference * coreml : simlpify whisper_encode + log messages * whisper : resolve rebase conflicts * coreml : add scripts for CoreML model generation * bench-all : recognize COREML flag
1 parent c6078a6 commit e939dc3

15 files changed

+1404
-26
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
*.o
22
*.a
33
.cache/
4+
.coreml/
45
.test/
56
.vs/
67
.vscode/
@@ -35,4 +36,6 @@ examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
3536

3637
extra/bench-gg.txt
3738

38-
*.mlmodel*
39+
models/*.mlmodel
40+
models/*.mlmodelc
41+
models/*.mlpackage

CMakeLists.txt

+59-9
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ if (APPLE)
5858
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
5959
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
6060
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
61+
62+
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
6163
else()
6264
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
6365
endif()
@@ -90,16 +92,33 @@ endif()
9092

9193
find_package(Threads REQUIRED)
9294

93-
# on APPLE - include Accelerate framework
94-
if (APPLE AND NOT WHISPER_NO_ACCELERATE)
95-
find_library(ACCELERATE_FRAMEWORK Accelerate)
96-
if (ACCELERATE_FRAMEWORK)
97-
message(STATUS "Accelerate framework found")
95+
# on APPLE
96+
if (APPLE)
97+
# include Accelerate framework
98+
if (NOT WHISPER_NO_ACCELERATE)
99+
find_library(ACCELERATE_FRAMEWORK Accelerate)
100+
101+
if (ACCELERATE_FRAMEWORK)
102+
message(STATUS "Accelerate framework found")
98103

99-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
100-
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
101-
else()
102-
message(WARNING "Accelerate framework not found")
104+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
105+
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
106+
else()
107+
message(WARNING "Accelerate framework not found")
108+
endif()
109+
endif()
110+
111+
if (WHISPER_COREML)
112+
find_library(FOUNDATION_FRAMEWORK Foundation)
113+
find_library(COREML_FRAMEWORK CoreML)
114+
115+
if (COREML_FRAMEWORK)
116+
message(STATUS "CoreML framework found")
117+
118+
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
119+
else()
120+
message(WARNING "CoreML framework not found")
121+
endif()
103122
endif()
104123
endif()
105124

@@ -187,6 +206,33 @@ if (WHISPER_PERF)
187206
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
188207
endif()
189208

209+
#
210+
# whisper.coreml - Core ML support
211+
#
212+
213+
if (WHISPER_COREML)
214+
set(TARGET whisper.coreml)
215+
216+
add_library(${TARGET}
217+
coreml/whisper-encoder.h
218+
coreml/whisper-encoder.mm
219+
coreml/whisper-encoder-impl.h
220+
coreml/whisper-encoder-impl.m
221+
)
222+
223+
include(DefaultTargetOptions)
224+
225+
target_include_directories(${TARGET} PUBLIC
226+
.
227+
)
228+
229+
target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
230+
231+
set_target_properties(${TARGET} PROPERTIES
232+
COMPILE_FLAGS "-fobjc-arc"
233+
)
234+
endif()
235+
190236
#
191237
# whisper - this is the main library of the project
192238
#
@@ -206,6 +252,10 @@ target_include_directories(${TARGET} PUBLIC
206252
.
207253
)
208254

255+
if (WHISPER_COREML)
256+
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
257+
endif()
258+
209259
if (MSVC)
210260
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
211261

Makefile

+32-16
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ ifndef WHISPER_NO_ACCELERATE
140140
LDFLAGS += -framework Accelerate
141141
endif
142142
endif
143+
ifdef WHISPER_COREML
144+
CXXFLAGS += -DWHISPER_USE_COREML
145+
LDFLAGS += -framework Foundation -framework CoreML
146+
endif
143147
ifdef WHISPER_OPENBLAS
144148
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
145149
LDFLAGS += -lopenblas
@@ -195,11 +199,23 @@ ggml.o: ggml.c ggml.h
195199
whisper.o: whisper.cpp whisper.h ggml.h
196200
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
197201

198-
libwhisper.a: ggml.o whisper.o
199-
$(AR) rcs libwhisper.a ggml.o whisper.o
202+
ifndef WHISPER_COREML
203+
WHISPER_OBJ = whisper.o
204+
else
205+
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
206+
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
207+
208+
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
209+
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
210+
211+
WHISPER_OBJ = whisper.o whisper-encoder.o whisper-encoder-impl.o
212+
endif
213+
214+
libwhisper.a: ggml.o $(WHISPER_OBJ)
215+
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
200216

201-
libwhisper.so: ggml.o whisper.o
202-
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
217+
libwhisper.so: ggml.o $(WHISPER_OBJ)
218+
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
203219

204220
clean:
205221
rm -f *.o main stream command talk talk-llama bench libwhisper.a libwhisper.so
@@ -213,24 +229,24 @@ CC_SDL=`sdl2-config --cflags --libs`
213229
SRC_COMMON = examples/common.cpp
214230
SRC_COMMON_SDL = examples/common-sdl.cpp
215231

216-
main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
217-
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
232+
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
233+
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
218234
./main -h
219235

220-
bench: examples/bench/bench.cpp ggml.o whisper.o
221-
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
236+
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
237+
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
222238

223-
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
224-
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
239+
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
240+
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
225241

226-
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
227-
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
242+
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
243+
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
228244

229-
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
230-
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
245+
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
246+
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
231247

232-
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
233-
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk-llama $(CC_SDL) $(LDFLAGS)
248+
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
249+
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
234250

235251
#
236252
# Audio samples

coreml/whisper-decoder-impl.h

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
//
2+
// whisper-decoder-impl.h
3+
//
4+
// This file was automatically generated and should not be edited.
5+
//
6+
7+
#import <Foundation/Foundation.h>
8+
#import <CoreML/CoreML.h>
9+
#include <stdint.h>
10+
#include <os/log.h>
11+
12+
NS_ASSUME_NONNULL_BEGIN
13+
14+
15+
/// Model Prediction Input Type
16+
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
17+
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
18+
19+
/// token_data as 1 by 1 matrix of 32-bit integers
20+
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
21+
22+
/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
23+
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
24+
- (instancetype)init NS_UNAVAILABLE;
25+
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
26+
27+
@end
28+
29+
30+
/// Model Prediction Output Type
31+
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
32+
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
33+
34+
/// var_1346 as multidimensional array of floats
35+
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
36+
- (instancetype)init NS_UNAVAILABLE;
37+
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
38+
39+
@end
40+
41+
42+
/// Class for model loading and prediction
43+
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
44+
@interface whisper_decoder_impl : NSObject
45+
@property (readonly, nonatomic, nullable) MLModel * model;
46+
47+
/**
48+
URL of the underlying .mlmodelc directory.
49+
*/
50+
+ (nullable NSURL *)URLOfModelInThisBundle;
51+
52+
/**
53+
Initialize whisper_decoder_impl instance from an existing MLModel object.
54+
55+
Usually the application does not use this initializer unless it makes a subclass of whisper_decoder_impl.
56+
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
57+
*/
58+
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
59+
60+
/**
61+
Initialize whisper_decoder_impl instance with the model in this bundle.
62+
*/
63+
- (nullable instancetype)init;
64+
65+
/**
66+
Initialize whisper_decoder_impl instance with the model in this bundle.
67+
68+
@param configuration The model configuration object
69+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
70+
*/
71+
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
72+
73+
/**
74+
Initialize whisper_decoder_impl instance from the model URL.
75+
76+
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
77+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
78+
*/
79+
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
80+
81+
/**
82+
Initialize whisper_decoder_impl instance from the model URL.
83+
84+
@param modelURL URL to the .mlmodelc directory for whisper_decoder_impl.
85+
@param configuration The model configuration object
86+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
87+
*/
88+
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
89+
90+
/**
91+
Construct whisper_decoder_impl instance asynchronously with configuration.
92+
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
93+
94+
@param configuration The model configuration
95+
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
96+
*/
97+
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
98+
99+
/**
100+
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
101+
102+
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
103+
104+
@param modelURL The model URL.
105+
@param configuration The model configuration
106+
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
107+
*/
108+
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
109+
110+
/**
111+
Make a prediction using the standard interface
112+
@param input an instance of whisper_decoder_implInput to predict from
113+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
114+
@return the prediction as whisper_decoder_implOutput
115+
*/
116+
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
117+
118+
/**
119+
Make a prediction using the standard interface
120+
@param input an instance of whisper_decoder_implInput to predict from
121+
@param options prediction options
122+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
123+
@return the prediction as whisper_decoder_implOutput
124+
*/
125+
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
126+
127+
/**
128+
Make a prediction using the convenience interface
129+
@param token_data as 1 by 1 matrix of 32-bit integers:
130+
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
131+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
132+
@return the prediction as whisper_decoder_implOutput
133+
*/
134+
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error;
135+
136+
/**
137+
Batch prediction
138+
@param inputArray array of whisper_decoder_implInput instances to obtain predictions from
139+
@param options prediction options
140+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
141+
@return the predictions as NSArray<whisper_decoder_implOutput *>
142+
*/
143+
- (nullable NSArray<whisper_decoder_implOutput *> *)predictionsFromInputs:(NSArray<whisper_decoder_implInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
144+
@end
145+
146+
NS_ASSUME_NONNULL_END

0 commit comments

Comments
 (0)