Skip to content

Commit cf00916

Browse files
committed
[examples/batched.swift] add an example swift implementation of batched & build it in ci to validate SPM is correctly configured for dependencies
1 parent 233fc1c commit cf00916

File tree

8 files changed

+308
-1
lines changed

8 files changed

+308
-1
lines changed

.github/workflows/build.yml

+5
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ jobs:
276276
run: |
277277
xcodebuild -scheme llama -destination "${{ matrix.destination }}"
278278
279+
- name: Build Swift Example
280+
id: make_build_swift_example
281+
run: |
282+
make swift
283+
279284
windows-latest-cmake:
280285
runs-on: windows-latest
281286

Makefile

+6-1
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,11 @@ metal: examples/metal/metal.cpp ggml.o $(OBJS)
617617
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
618618
endif
619619

620+
ifeq ($(UNAME_S),Darwin)
621+
swift: examples/batched.swift
622+
(cd examples/batched.swift; make build)
623+
endif
624+
620625
build-info.h: $(wildcard .git/index) scripts/build-info.sh
621626
@sh scripts/build-info.sh $(CC) > $@.tmp
622627
@if ! cmp -s $@.tmp $@; then \
@@ -637,7 +642,7 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
637642
run-benchmark-matmult: benchmark-matmult
638643
./$@
639644

640-
.PHONY: run-benchmark-matmult
645+
.PHONY: run-benchmark-matmult swift
641646

642647
vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
643648
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

examples/batched.swift/.gitignore

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
.DS_Store
2+
/.build
3+
/Packages
4+
xcuserdata/
5+
DerivedData/
6+
.swiftpm/configuration/registries.json
7+
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
8+
.netrc
9+
swift

examples/batched.swift/Makefile

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
.PHONY: build
2+
3+
build:
4+
xcodebuild -scheme batched_swift -destination "generic/platform=macOS" -derivedDataPath build
5+
rm -f ./batched_swift
6+
ln -s ./build/Build/Products/Debug/batched_swift ./batched_swift

examples/batched.swift/Package.swift

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// swift-tools-version: 5.5
2+
// The swift-tools-version declares the minimum version of Swift required to build this package.
3+
4+
import PackageDescription
5+
6+
let package = Package(
7+
name: "batched_swift",
8+
platforms: [.macOS(.v12)],
9+
dependencies: [
10+
.package(name: "llama", path: "../../"),
11+
],
12+
targets: [
13+
// Targets are the basic building blocks of a package, defining a module or a test suite.
14+
// Targets can depend on other targets in this package and products from dependencies.
15+
.executableTarget(
16+
name: "batched_swift",
17+
dependencies: ["llama"],
18+
path: "Sources",
19+
linkerSettings: [.linkedFramework("Foundation"), .linkedFramework("AppKit")]
20+
),
21+
]
22+
)

examples/batched.swift/README.md

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
This is a swift clone of `examples/batched`.
2+
3+
$ `make`
4+
$ `./swift MODEL_PATH [PROMPT] [PARALLEL]`
+255
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import Foundation
2+
import llama
3+
4+
let arguments = CommandLine.arguments
5+
6+
// Check that we have at least one argument (the model path)
7+
guard arguments.count > 1 else {
8+
print("Usage: swift MODEL_PATH [PROMPT] [PARALLEL]")
9+
exit(1)
10+
}
11+
12+
let modelPath: String = arguments[1]
13+
let prompt: String = arguments.count > 2 ? arguments[2] : "Hello my name is"
14+
let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(arguments[3])! : 1
15+
16+
// total length of the sequences including the prompt
17+
let n_len: Int = 32
18+
19+
// init LLM
20+
llama_backend_init(false)
21+
defer {
22+
llama_backend_free()
23+
}
24+
25+
let model_params = llama_model_default_params()
26+
guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else {
27+
print("Failed to load model")
28+
exit(1)
29+
}
30+
31+
defer {
32+
llama_free_model(model)
33+
}
34+
35+
var tokens = tokenize(text: prompt, add_bos: true)
36+
37+
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
38+
39+
var context_params = llama_context_default_params()
40+
context_params.seed = 1234
41+
context_params.n_ctx = n_kv_req
42+
context_params.n_batch = UInt32(max(n_len, n_parallel))
43+
context_params.n_threads = 8
44+
context_params.n_threads_batch = 8
45+
46+
let context = llama_new_context_with_model(model, context_params)
47+
guard context != nil else {
48+
print("Failed to initialize context")
49+
exit(1)
50+
}
51+
52+
defer {
53+
llama_free(context)
54+
}
55+
56+
let n_ctx = llama_n_ctx(context)
57+
58+
print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
59+
60+
if n_kv_req > n_ctx {
61+
print("error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", n_kv_req)
62+
exit(1)
63+
}
64+
65+
var buffer: [CChar] = []
66+
for id: llama_token in tokens {
67+
print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "")
68+
}
69+
70+
print("\n")
71+
72+
var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0)
73+
defer {
74+
llama_batch_free(batch)
75+
}
76+
77+
// evaluate the initial prompt
78+
batch.n_tokens = Int32(tokens.count)
79+
80+
for (i, token) in tokens.enumerated() {
81+
batch.token[i] = token
82+
batch.pos[i] = Int32(i)
83+
batch.seq_id[i] = 0
84+
batch.logits[i] = 0
85+
}
86+
87+
// llama_decode will output logits only for the last token of the prompt
88+
batch.logits[Int(batch.n_tokens) - 1] = 1
89+
90+
if llama_decode(context, batch) != 0 {
91+
print("llama_decode() failed")
92+
exit(1)
93+
}
94+
95+
for i in 1 ..< n_parallel {
96+
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
97+
}
98+
99+
if n_parallel > 1 {
100+
print("generating \(n_parallel) sequences ...\n")
101+
}
102+
103+
var streams: [String] = .init(repeating: "", count: n_parallel)
104+
var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel)
105+
var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel)
106+
107+
var n_cur = batch.n_tokens
108+
var n_decode = 0
109+
110+
let t_main_start = ggml_time_us()
111+
112+
while n_cur <= n_len {
113+
// prepare the next batch
114+
batch.n_tokens = 0
115+
116+
// sample the next token for each parallel sequence / stream
117+
for i in 0 ..< n_parallel {
118+
if i_batch[i] < 0 {
119+
// the stream has already finished
120+
continue
121+
}
122+
123+
var n_vocab = llama_n_vocab(model)
124+
var logits = llama_get_logits_ith(context, i_batch[i])
125+
126+
var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
127+
128+
for token_id in 0 ..< n_vocab {
129+
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
130+
}
131+
132+
var candidates_p: llama_token_data_array = .init(
133+
data: &candidates,
134+
size: candidates.count,
135+
sorted: false
136+
)
137+
138+
let top_k: Int32 = 40
139+
let top_p: Float = 0.9
140+
let temp: Float = 0.4
141+
142+
llama_sample_top_k(context, &candidates_p, top_k, 1)
143+
llama_sample_top_p(context, &candidates_p, top_p, 1)
144+
llama_sample_temp(context, &candidates_p, temp)
145+
146+
let new_token_id = llama_sample_token(context, &candidates_p)
147+
148+
// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
149+
150+
// is it an end of stream? -> mark the stream as finished
151+
if new_token_id == llama_token_eos(context) || n_cur == n_len {
152+
i_batch[i] = -1
153+
// print("")
154+
if n_parallel > 1 {
155+
print("stream \(i) finished at n_cur = \(n_cur)")
156+
}
157+
158+
continue
159+
}
160+
161+
let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? ""
162+
163+
// if there is only one stream, we print immediately to stdout
164+
if n_parallel == 1 {
165+
print(nextStringPiece, terminator: "")
166+
}
167+
streams[i] += nextStringPiece
168+
169+
// push this new token for next evaluation
170+
batch.token[Int(batch.n_tokens)] = new_token_id
171+
batch.pos[Int(batch.n_tokens)] = n_cur
172+
batch.seq_id[Int(batch.n_tokens)] = Int32(i)
173+
batch.logits[Int(batch.n_tokens)] = 1
174+
175+
i_batch[i] = batch.n_tokens
176+
177+
batch.n_tokens += 1
178+
179+
n_decode += 1
180+
}
181+
182+
// all streams are finished
183+
if batch.n_tokens == 0 {
184+
break
185+
}
186+
187+
n_cur += 1
188+
189+
// evaluate the current batch with the transformer model
190+
if llama_decode(context, batch) != 0 {
191+
print("llama_decode() failed")
192+
exit(1)
193+
}
194+
}
195+
196+
if n_parallel > 1 {
197+
print("\n")
198+
for (i, stream) in streams.enumerated() {
199+
print("sequence \(i):\n\n\(prompt)\(stream)\n")
200+
}
201+
}
202+
203+
let t_main_end = ggml_time_us()
204+
205+
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
206+
207+
llama_print_timings(context)
208+
209+
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
210+
let n_tokens = text.count + (add_bos ? 1 : 0)
211+
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
212+
let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos)
213+
var swiftTokens: [llama_token] = []
214+
for i in 0 ..< tokenCount {
215+
swiftTokens.append(tokens[Int(i)])
216+
}
217+
tokens.deallocate()
218+
return swiftTokens
219+
}
220+
221+
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
222+
var result = [CChar](repeating: 0, count: 8)
223+
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count))
224+
if nTokens < 0 {
225+
if result.count >= -Int(nTokens) {
226+
result.removeLast(-Int(nTokens))
227+
} else {
228+
result.removeAll()
229+
}
230+
let check = llama_token_to_piece(
231+
model,
232+
token,
233+
&result,
234+
Int32(result.count)
235+
)
236+
assert(check == nTokens)
237+
} else {
238+
result.removeLast(result.count - Int(nTokens))
239+
}
240+
if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) {
241+
return utfString
242+
} else {
243+
buffer.append(contentsOf: result)
244+
let data = Data(buffer.map { UInt8(bitPattern: $0) })
245+
if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer
246+
buffer = []
247+
}
248+
guard let bufferString = String(data: data, encoding: .utf8) else {
249+
return nil
250+
}
251+
buffer = []
252+
return bufferString
253+
}
254+
return nil
255+
}

examples/batched.swift/batched_swift

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
./build/Build/Products/Debug/batched_swift

0 commit comments

Comments
 (0)