Skip to content

Commit

Permalink
[feat] ios demo support qwen2-vl and metal.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Sep 14, 2024
1 parent b70fb08 commit 8d8547f
Show file tree
Hide file tree
Showing 6 changed files with 578 additions and 153 deletions.
357 changes: 357 additions & 0 deletions ios/mnn-llm/mnn-llm.xcodeproj/project.pbxproj

Large diffs are not rendered by default.

Binary file not shown.
214 changes: 178 additions & 36 deletions ios/mnn-llm/mnn-llm/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,44 @@
// Created by wangzhaode on 2023/12/14.
//

import Combine
import SwiftUI
import Combine
import PhotosUI

class ChatViewModel: ObservableObject {
@Published var messages: [Message] = []
@Published var isModelLoaded = false // 模型是否加载完成
@Published var isProcessing: Bool = false // 标志表示是否有正在处理的LLM响应
@Published var isModelLoaded = false
@Published var isProcessing: Bool = false
@Published var pendingImagePath: String?
private var llm: LLMInferenceEngineWrapper?

init() {
self.messages.append(Message(id: UUID(), text: "qwen1.5-0.5b-chat 模型加载中, 请稍等 ...", isUser: false))
self.messages.append(Message(id: UUID(), text: "qwen2-vl-2b-instruct 模型加载中, 请稍等 ...", isUser: false))
llm = LLMInferenceEngineWrapper { [weak self] success in
DispatchQueue.main.async {
self?.isModelLoaded = success
var loadresult = "模型加载完毕!"
if !success {
loadresult = "模型加载失败!"
}
self?.messages.append(Message(id: UUID(), text: loadresult, isUser: false))
let loadResult = success ? "模型加载完毕!" : "模型加载失败!"
self?.messages.append(Message(id: UUID(), text: loadResult, isUser: false))
}
}
}

func sendInput(_ input: String) {
// 将用户输入作为新消息添加
let userMessage = Message(id: UUID(), text: input, isUser: true)
var combinedInput = input
if let imagePath = pendingImagePath {
combinedInput = "<img>\(imagePath)</img>" + combinedInput
}

let userMessage = Message(id: UUID(), text: input, imagePath: pendingImagePath, isUser: true)
DispatchQueue.main.async {
self.messages.append(userMessage)
self.pendingImagePath = nil
}
isProcessing = true
// 在后台线程处理耗时的输入
DispatchQueue.global(qos: .userInitiated).async {
self.llm?.processInput(input) { [weak self] output in
// 切换回主线程来更新UI
self.llm?.processInput(combinedInput) { [weak self] output in
DispatchQueue.main.async {
if (output.contains("<eop>")) {
if output.contains("<eop>") {
self?.isProcessing = false
} else {
self?.appendResponse(output)
Expand All @@ -49,44 +51,83 @@ class ChatViewModel: ObservableObject {
}
}
}


/// Encodes `image` as PNG and writes it to a uniquely named file in the temporary directory.
///
/// - Parameters:
///   - image: The image to persist.
///   - completion: Invoked inline (before this method returns) with the path of the
///     written file, or with `nil` when PNG encoding or the file write fails.
func saveImageToTemporaryFile(_ image: UIImage, completion: @escaping (String?) -> Void) {
    // Bail out early when the image cannot be represented as PNG data.
    guard let pngBytes = image.pngData() else {
        completion(nil)
        return
    }

    // A fresh UUID per call keeps successive saves from overwriting each other.
    let destination = FileManager.default.temporaryDirectory
        .appendingPathComponent(UUID().uuidString)
        .appendingPathExtension("png")

    do {
        try pngBytes.write(to: destination)
        completion(destination.path)
    } catch {
        print("Error saving image to file: \(error)")
        completion(nil)
    }
}

/// Appends a streamed chunk of model output to the conversation.
///
/// If the most recent message is a non-user message, `output` is concatenated onto
/// its text; otherwise a new non-user message containing `output` is appended.
///
/// - Parameter output: The text chunk to add.
private func appendResponse(_ output: String) {
    if let lastIndex = messages.indices.last, !messages[lastIndex].isUser {
        // `Message.text` is declared optional, so fall back to an empty string
        // instead of force-unwrapping (which would crash on a nil text).
        var updatedMessage = messages[lastIndex]
        updatedMessage.text = (updatedMessage.text ?? "") + output
        // Write the modified copy back so the published array changes.
        messages[lastIndex] = updatedMessage
    } else {
        messages.append(Message(id: UUID(), text: output, isUser: false))
    }
}

/// Removes the file at `path`.
///
/// Failures are logged with `print` and not propagated to the caller.
///
/// - Parameter path: File-system path of the item to remove.
func deleteTemporaryFile(at path: String) {
    do {
        try FileManager.default.removeItem(atPath: path)
    } catch {
        print("Error deleting temporary file: \(error)")
    }
}
}


struct Message: Identifiable, Equatable {
let id: UUID
var text: String
var text: String?
var imagePath: String?
let isUser: Bool
}

struct ChatBubble: View {
let message: Message

var body: some View {
HStack {
if message.isUser {
Spacer()
}

Text(message.text)
.padding(10)
.foregroundColor(message.isUser ? .white : .black)
.background(message.isUser ? Color.blue : Color.gray.opacity(0.2))
.cornerRadius(10)
.frame(maxWidth: 400, alignment: message.isUser ? .trailing : .leading)


VStack(alignment: .leading) {
if let imagePath = message.imagePath, let image = UIImage(contentsOfFile: imagePath) {
Image(uiImage: image)
.resizable()
.aspectRatio(contentMode: .fit)
.frame(maxWidth: 300, maxHeight: 200)
.padding(10)
.background(message.isUser ? Color.blue : Color.gray.opacity(0.2))
.cornerRadius(10)
}

if let text = message.text {
Text(text)
.padding(10)
.foregroundColor(message.isUser ? .white : .black)
.background(message.isUser ? Color.blue : Color.gray.opacity(0.2))
.cornerRadius(10)
.frame(maxWidth: 400, alignment: message.isUser ? .trailing : .leading)
}
}

if !message.isUser {
Spacer()
}
Expand All @@ -98,9 +139,13 @@ struct ChatBubble: View {
struct ChatView: View {
@StateObject var viewModel = ChatViewModel()
@State private var inputText: String = ""

@State private var selectedImage: UIImage?
@State private var isImagePickerPresented: Bool = false
@State private var isImageSourcePickerPresented: Bool = false
@State private var imageSourceType: UIImagePickerController.SourceType = .photoLibrary

var body: some View {
NavigationView { // 包裹在 NavigationView 中
NavigationView {
VStack {
ScrollView {
ScrollViewReader { scrollView in
Expand All @@ -111,30 +156,127 @@ struct ChatView: View {
}
.padding(.horizontal)
.onChange(of: viewModel.messages) { _ in
scrollView.scrollTo(viewModel.messages.last?.id, anchor: .bottom)
if let lastMessageId = viewModel.messages.last?.id {
withAnimation {
scrollView.scrollTo(lastMessageId, anchor: .bottom)
}
}
}
}
}

HStack {
if let image = selectedImage {
Image(uiImage: image)
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 50, height: 50)
.padding()
.background(Color.gray.opacity(0.2))
.cornerRadius(10)

Button(action: {
selectedImage = nil
}) {
Image(systemName: "xmark.circle.fill")
.foregroundColor(.red)
}
}

TextField("Type a message...", text: $inputText)
.textFieldStyle(RoundedBorderTextFieldStyle())
.frame(minHeight: 44)

Button(action: {
viewModel.sendInput(inputText)
inputText = ""
if let image = selectedImage {
viewModel.saveImageToTemporaryFile(image) { imagePath in
if let imagePath = imagePath {
viewModel.pendingImagePath = imagePath
viewModel.sendInput(inputText)
inputText = ""
}
}
} else {
viewModel.sendInput(inputText)
inputText = ""
}
selectedImage = nil
}) {
Image(systemName: "arrow.up.circle.fill")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 44, height: 44)
}
.disabled(inputText.isEmpty || viewModel.isProcessing || !viewModel.isModelLoaded)
.disabled(inputText.isEmpty && selectedImage == nil || viewModel.isProcessing || !viewModel.isModelLoaded)

Button(action: {
isImageSourcePickerPresented = true
}) {
Image(systemName: "photo.on.rectangle.angled")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 44, height: 44)
}
}
.padding()
.actionSheet(isPresented: $isImageSourcePickerPresented) {
ActionSheet(
title: Text("Select Image Source"),
buttons: [
.default(Text("Photo Library")) {
imageSourceType = .photoLibrary
isImagePickerPresented = true
},
.default(Text("Camera")) {
imageSourceType = .camera
isImagePickerPresented = true
},
.cancel()
]
)
}
.sheet(isPresented: $isImagePickerPresented) {
ImagePicker(selectedImage: $selectedImage, sourceType: imageSourceType)
}
}
.navigationBarTitle("mnn-llm", displayMode: .inline)
}
}
}

/// SwiftUI wrapper that presents a `UIImagePickerController` and writes the
/// picked image into `selectedImage`.
struct ImagePicker: UIViewControllerRepresentable {
    /// Receives the original image the user picked; left untouched on cancel.
    @Binding var selectedImage: UIImage?
    /// Requested picker source. Used only when it is available on this device.
    var sourceType: UIImagePickerController.SourceType = .photoLibrary

    func makeCoordinator() -> Coordinator {
        Coordinator(self)
    }

    func makeUIViewController(context: Context) -> UIImagePickerController {
        let picker = UIImagePickerController()
        picker.delegate = context.coordinator
        // Assigning an unavailable source type (e.g. `.camera` on the simulator)
        // raises an exception, so only apply it when UIKit reports it as available;
        // otherwise the controller keeps its default source type.
        if UIImagePickerController.isSourceTypeAvailable(sourceType) {
            picker.sourceType = sourceType
        }
        return picker
    }

    // Nothing to synchronise: the picker is configured once at creation.
    func updateUIViewController(_ uiViewController: UIImagePickerController, context: Context) {}

    /// Delegate that forwards picker results back to the owning `ImagePicker`.
    class Coordinator: NSObject, UIImagePickerControllerDelegate, UINavigationControllerDelegate {
        var parent: ImagePicker

        init(_ parent: ImagePicker) {
            self.parent = parent
        }

        /// Stores the picked original image (if any) in the binding and dismisses the picker.
        func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey : Any]) {
            if let image = info[.originalImage] as? UIImage {
                parent.selectedImage = image
            }
            picker.dismiss(animated: true)
        }

        /// Dismisses the picker without changing the binding.
        func imagePickerControllerDidCancel(_ picker: UIImagePickerController) {
            picker.dismiss(animated: true)
        }
    }
}
Expand Down
2 changes: 1 addition & 1 deletion ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ - (instancetype)initWithCompletionHandler:(ModelLoadingCompletionHandler)complet
- (BOOL)loadModel {
if (!llm) {
std::string model_dir = GetMainBundleDirectory();
std::string config_path = model_dir + "/qwen1.5-0.5b-chat/config.json";
std::string config_path = model_dir + "/qwen2-vl-2b-instruct/config.json";
llm = Llm::createLLM(config_path);
llm->load();
}
Expand Down
9 changes: 6 additions & 3 deletions src/llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,9 @@ class Lvlm : public Llm {
image_mean_ = config->llm_config_.value("image_mean", image_mean_);
image_norm_ = config->llm_config_.value("image_norm", image_norm_);
}
~Lvlm() { visual_module_.reset(); }
~Lvlm() {
visual_module_.reset();
}
virtual void load() override;
virtual std::vector<int> tokenizer(const std::string& query) override;
virtual MNN::Express::VARP embedding(const std::vector<int>& input_ids) override;
Expand Down Expand Up @@ -637,6 +639,7 @@ void Lvlm::load() {

std::vector<int> Lvlm::image_process(const std::string& image_info) {
#ifdef LLM_SUPPORT_VISION
AUTOTIME;
VARP image = nullptr;
if (image_info.substr(0, 4) == "http") {
std::regex url_regex(R"(^https?://([^/]+)(/.*))");
Expand Down Expand Up @@ -691,7 +694,7 @@ std::vector<int> Lvlm::tokenizer(const std::string& query) {
std::smatch match;
std::vector<std::string> img_infos;
std::vector<int> ids {};

image_embeddings_.clear();
while (std::regex_search(searchStart, prompt.cend(), match, img_regex)) {
// std::cout << "img match: " << match[1].str() << std::endl;
auto txt_ids = tokenizer_->encode(match.prefix().str());
Expand All @@ -704,7 +707,7 @@ std::vector<int> Lvlm::tokenizer(const std::string& query) {
auto txt_ids = tokenizer_->encode(std::string(searchStart, prompt.cend()));
ids.insert(ids.end(), txt_ids.begin(), txt_ids.end());
}
// printf("ids = ["); for (auto id : ids) printf("%d, ", id); printf("]\n");
// printf("ids (%lu) = [", ids.size()); for (auto id : ids) printf("%d, ", id); printf("]\n");
return ids;
}

Expand Down
Loading

0 comments on commit 8d8547f

Please sign in to comment.