iOS example with swift ui #4159
Merged
Commits (17, all by bachittle):
62ca9dc copy to llama.cpp as subdir
f632a3c attempt enabling metal, fails
7d565c9 ggml metal compiles!
67a0e8b Update README.md
5e97a60 Merge branch 'master' into swiftui_metal
ae6beb4 initial conversion to new format, utf8 errors?
090383b bug fixes, but now has an invalid memory access :(
cd61854 added O3, now has insufficient memory access
f510cc1 Merge branch 'master' into swiftui_metal_update
ce31d95 begin sync with master
a22264a update to match latest code, new errors
f002a2e fixed it!
31fbcf6 Merge branch 'master' into swiftui_metal_update
1e65f66 fix for loop conditionals, increase result size
ee07308 fix current workflow errors
256478a attempt a llama.swiftui workflow
af05571 Update .github/workflows/build.yml
.gitignore (new file):
@@ -0,0 +1 @@
xcuserdata
README.md (new file):
@@ -0,0 +1,7 @@
# llama.swiftui

Local inference of llama.cpp on an iPhone.
So far I have only tested with the starcoder 1B model, but it can most likely handle 7B models as well.

https://github.com/bachittle/llama.cpp/assets/39804642/e290827a-4edb-4093-9642-2a5e399ec545
Swift wrapper around the llama C API (new file):
@@ -0,0 +1,176 @@
import Foundation

// import llama

enum LlamaError: Error {
    case couldNotInitializeContext
}

actor LlamaContext {
    private var model: OpaquePointer
    private var context: OpaquePointer
    private var batch: llama_batch
    private var tokens_list: [llama_token]

    var n_len: Int32 = 32
    var n_cur: Int32 = 0
    var n_decode: Int32 = 0

    init(model: OpaquePointer, context: OpaquePointer) {
        self.model = model
        self.context = context
        self.tokens_list = []
        self.batch = llama_batch_init(512, 0, 1)
    }

    deinit {
        // free the batch allocated in init() along with the context and model
        llama_batch_free(batch)
        llama_free(context)
        llama_free_model(model)
        llama_backend_free()
    }

    static func createContext(path: String) throws -> LlamaContext {
        llama_backend_init(false)
        let model_params = llama_model_default_params()

        let model = llama_load_model_from_file(path, model_params)
        guard let model else {
            print("Could not load model at \(path)")
            throw LlamaError.couldNotInitializeContext
        }
        var ctx_params = llama_context_default_params()
        ctx_params.seed = 1234
        ctx_params.n_ctx = 2048
        ctx_params.n_threads = 8
        ctx_params.n_threads_batch = 8

        let context = llama_new_context_with_model(model, ctx_params)
        guard let context else {
            print("Could not load context!")
            throw LlamaError.couldNotInitializeContext
        }

        return LlamaContext(model: model, context: context)
    }

    func get_n_tokens() -> Int32 {
        return batch.n_tokens
    }

    func completion_init(text: String) {
        print("attempting to complete \"\(text)\"")

        tokens_list = tokenize(text: text, add_bos: true)

        let n_ctx = llama_n_ctx(context)
        let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)

        print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")

        if n_kv_req > n_ctx {
            print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
        }

        for id in tokens_list {
            print(token_to_piece(token: id))
        }

        // batch was already created in init() via llama_batch_init(512, 0, 1)
        batch.n_tokens = Int32(tokens_list.count)

        // fill the batch with the prompt tokens; half-open range avoids an
        // off-by-one and handles an empty prompt safely
        for i in 0..<Int(batch.n_tokens) {
            batch.token[i] = tokens_list[i]
            batch.pos[i] = Int32(i)
            batch.n_seq_id[i] = 1
            batch.seq_id[i]![0] = 0
            batch.logits[i] = 0
        }
        // only request logits for the last prompt token
        batch.logits[Int(batch.n_tokens) - 1] = 1 // true

        if llama_decode(context, batch) != 0 {
            print("llama_decode() failed")
        }

        n_cur = batch.n_tokens
    }

    func completion_loop() -> String {
        var new_token_id: llama_token = 0

        let n_vocab = llama_n_vocab(model)
        let logits = llama_get_logits_ith(context, batch.n_tokens - 1)

        var candidates = Array<llama_token_data>()
        candidates.reserveCapacity(Int(n_vocab))

        // valid token ids are 0..<n_vocab; a closed range would read one past the end
        for token_id in 0..<n_vocab {
            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
        }
        candidates.withUnsafeMutableBufferPointer() { buffer in
            var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)

            new_token_id = llama_sample_token_greedy(context, &candidates_p)
        }

        if new_token_id == llama_token_eos(context) || n_cur == n_len {
            print("\n")
            return ""
        }

        let new_token_str = token_to_piece(token: new_token_id)
        print(new_token_str)
        // tokens_list.append(new_token_id)

        // reuse the batch for the single newly sampled token
        batch.n_tokens = 0

        batch.token[Int(batch.n_tokens)] = new_token_id
        batch.pos[Int(batch.n_tokens)] = n_cur
        batch.n_seq_id[Int(batch.n_tokens)] = 1
        batch.seq_id[Int(batch.n_tokens)]![0] = 0
        batch.logits[Int(batch.n_tokens)] = 1 // true
        batch.n_tokens += 1

        n_decode += 1

        n_cur += 1

        if llama_decode(context, batch) != 0 {
            print("failed to evaluate llama!")
        }

        return new_token_str
    }

    func clear() {
        tokens_list.removeAll()
    }

    private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
        // the C API expects the UTF-8 byte count, not the character count
        let utf8Count = text.utf8.count
        let n_tokens = utf8Count + (add_bos ? 1 : 0)
        let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
        let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)

        var swiftTokens: [llama_token] = []
        for i in 0..<tokenCount {
            swiftTokens.append(tokens[Int(i)])
        }

        tokens.deallocate()

        return swiftTokens
    }

    private func token_to_piece(token: llama_token) -> String {
        let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
        result.initialize(repeating: Int8(0), count: 8)

        let _ = llama_token_to_piece(model, token, result, 8)

        let resultStr = String(cString: result)

        result.deallocate()

        return resultStr
    }
}
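For orientation, here is a minimal usage sketch. It is not part of the PR: the runDemo wrapper and the model path are hypothetical, and only the LlamaContext API shown above is assumed.

// Hypothetical driver, not part of this PR: exercises the actor above.
// The model path is a placeholder; point it at a real GGUF file.
func runDemo() async throws {
    let ctx = try LlamaContext.createContext(path: "/path/to/model.gguf")

    await ctx.completion_init(text: "def fibonacci(")

    // completion_loop() returns "" once EOS is sampled or n_len is reached
    var output = ""
    while true {
        let piece = await ctx.completion_loop()
        if piece.isEmpty { break }
        output += piece
    }
    print(output)
}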
Bridging header (new file):
@@ -0,0 +1,5 @@
//
// Use this file to import your target's public headers that you would like to expose to Swift.
//

#import "llama.h"
Review comment: not ..< ? Other C++ examples in the repo use <.
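For reference, a standalone Swift snippet (illustrative only, not from the PR) showing why the half-open range matters here:

// Illustrative only: Swift's closed range ... includes its upper bound,
// while the half-open ..< excludes it, matching the C idiom
// for (int i = 0; i < n; i++).
let n = 3
for i in 0...n { print(i) }   // prints 0, 1, 2, 3: one past the last valid index
for i in 0..<n { print(i) }   // prints 0, 1, 2: safe for indexing an array of count n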
Review comment: the rest looks good to me, thanks for this work.
Author reply: fixed.